From 98e8128ab06350bcad71ae00ce748a0ccb8344f2 Mon Sep 17 00:00:00 2001 From: Philipp Benner Date: Mon, 23 Feb 2026 15:10:09 +0100 Subject: [PATCH 01/15] cleaned up code; added train/predict scripts --- __init__.py | 0 fiora/GNN/AtomFeatureEncoder.py | 117 +- fiora/GNN/BondFeatureEncoder.py | 71 +- fiora/GNN/CovariateFeatureEncoder.py | 112 +- fiora/GNN/Datasets.py | 219 +- fiora/GNN/EdgePropertyPredictor.py | 133 +- fiora/GNN/FeatureEmbedding.py | 90 +- fiora/GNN/FioraModel.py | 238 +- fiora/GNN/GNN.py | 83 +- fiora/GNN/GNNCustomLayers.py | 5 +- fiora/GNN/GraphPropertyPredictor.py | 61 +- fiora/GNN/Losses.py | 82 +- fiora/GNN/MLP.py | 18 +- fiora/GNN/MLPEdgeClassifier.py | 12 +- fiora/GNN/SpectralTrainer.py | 319 +- fiora/GNN/Trainer.py | 102 +- fiora/IO/LibraryLoader.py | 29 +- fiora/IO/cfmReader.py | 38 +- fiora/IO/fraggraphReader.py | 23 +- fiora/IO/mgfReader.py | 50 +- fiora/IO/mgfWriter.py | 20 +- fiora/IO/molReader.py | 10 +- fiora/IO/mspReader.py | 88 +- fiora/IO/mspWriter.py | 31 +- fiora/IO/mspredReader.py | 23 +- fiora/IO/mspredWriter.py | 80 +- fiora/MOL/FragmentationTree.py | 291 +- fiora/MOL/Metabolite.py | 661 +++- fiora/MOL/MetaboliteDatasetStatistics.py | 58 +- fiora/MOL/MetaboliteIndex.py | 44 +- fiora/MOL/collision_energy.py | 48 +- fiora/MOL/constants.py | 107 +- fiora/MOL/mol_graph.py | 100 +- fiora/MS/SimulationFramework.py | 345 +- fiora/MS/ms_utility.py | 62 +- fiora/MS/spectral_scores.py | 120 +- fiora/cli/__init__.py | 1 + fiora/cli/predict.py | 362 ++ fiora/cli/train.py | 539 +++ fiora/visualization/define_colors.py | 160 +- fiora/visualization/inspect_mgf_file.py | 28 +- fiora/visualization/plot_spectrum.py | 42 +- fiora/visualization/spectrum_visualizer.py | 111 +- lib_loader/casmi16_loader.ipynb | 403 +- lib_loader/casmi22_loader.ipynb | 210 +- lib_loader/gnps_library_loader.ipynb | 1862 ++++++--- lib_loader/ms_dial_loader.ipynb | 289 +- lib_loader/msnlib_loader.ipynb | 435 +- lib_loader/nist_library_loader.ipynb | 571 ++- notebooks/break_tendency.ipynb | 1849 ++++++--- notebooks/grid_search.ipynb | 478 ++- notebooks/grid_stats.ipynb | 146 +- notebooks/info_graphs.ipynb | 358 +- notebooks/live_predict.ipynb | 72 +- notebooks/sandbox.ipynb | 1051 +++-- notebooks/test_model.ipynb | 4147 ++++++++++++++------ notebooks/train_model.ipynb | 959 +++-- pyproject.toml | 8 +- scripts/fiora-predict | 199 +- scripts/fiora-train | 6 + tests/test_fiora_predict.py | 67 +- 61 files changed, 12967 insertions(+), 5276 deletions(-) delete mode 100644 __init__.py create mode 100644 fiora/cli/__init__.py create mode 100644 fiora/cli/predict.py create mode 100644 fiora/cli/train.py create mode 100644 scripts/fiora-train diff --git a/__init__.py b/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/fiora/GNN/AtomFeatureEncoder.py b/fiora/GNN/AtomFeatureEncoder.py index 284678b..77daf10 100644 --- a/fiora/GNN/AtomFeatureEncoder.py +++ b/fiora/GNN/AtomFeatureEncoder.py @@ -1,61 +1,81 @@ -import torch -import numpy as np +import torch from rdkit import Chem from typing import Literal from fiora.MOL.constants import ORDERED_ELEMENT_LIST - - class AtomFeatureEncoder: def __init__(self, feature_list=["symbol", "num_hydrogen", "ring_type"]): self.encoding_dim = 0 self.sets = { - "symbol": ORDERED_ELEMENT_LIST, #OTHERS: Au, Se, Si #standard list {"B", "Br", "C", "Ca", "Cl", "F", "H", "I", "N", "Na", "O", "P", "S"}, - "num_hydrogen": [0, 1, 2, 3], #OTHERS: 5, 6, 7, 8}, + "symbol": ORDERED_ELEMENT_LIST, # OTHERS: Au, Se, Si #standard list {"B", "Br", "C", "Ca", "Cl", "F", 
"H", "I", "N", "Na", "O", "P", "S"}, + "num_hydrogen": [0, 1, 2, 3], # OTHERS: 5, 6, 7, 8}, "ring_type": ["no-ring", "small-ring", "5-cycle", "6-cycle", "large-ring"], "hybridization": ["SP", "SP2", "SP3", "SP3D2"], - "valence_electrons": [1,2,3,4,5,6,7,8], - "oxidation_number": [1,2,3,4,5,6,7,8,9], + "valence_electrons": [1, 2, 3, 4, 5, 6, 7, 8], + "oxidation_number": [1, 2, 3, 4, 5, 6, 7, 8, 9], } self.feature_list = feature_list - self.reduced_features = ["symbol", "num_hydrogen", "hybridization"] # Reduced features may have additional values that will be encoded with another bit (representing OTHERS) - + self.reduced_features = [ + "symbol", + "num_hydrogen", + "hybridization", + ] # Reduced features may have additional values that will be encoded with another bit (representing OTHERS) + self.one_hot_mapper = {} self.number_mapper = {} self.feature_numbers = {} for feature in self.feature_list: variables = self.sets[feature] num_variables = len(variables) - self.one_hot_mapper[feature] = dict(zip(variables, range(self.encoding_dim, num_variables + self.encoding_dim))) + self.one_hot_mapper[feature] = dict( + zip( + variables, + range(self.encoding_dim, num_variables + self.encoding_dim), + ) + ) self.number_mapper[feature] = dict(zip(variables, range(0, num_variables))) self.encoding_dim += num_variables if feature in self.reduced_features: self.encoding_dim += 1 num_variables += 1 self.feature_numbers[feature] = num_variables - - def encode(self, G, encoder_type: Literal['one_hot', 'number']): - - if encoder_type == 'one_hot': - feature_matrix = torch.zeros(G.number_of_nodes(), self.encoding_dim, dtype=torch.float32) + + def encode(self, G, encoder_type: Literal["one_hot", "number"]): + + if encoder_type == "one_hot": + feature_matrix = torch.zeros( + G.number_of_nodes(), self.encoding_dim, dtype=torch.float32 + ) for i in range(G.number_of_nodes()): - atom = G.nodes()[i]['atom'] - - if 'symbol' in self.feature_list: - if not atom.GetSymbol() in self.sets['symbol']: - feature_matrix[i][self.one_hot_mapper['symbol'][list(self.sets['symbol'])[-1]] + 1] = 1.0 + atom = G.nodes()[i]["atom"] + + if "symbol" in self.feature_list: + if atom.GetSymbol() not in self.sets["symbol"]: + feature_matrix[i][ + self.one_hot_mapper["symbol"][list(self.sets["symbol"])[-1]] + + 1 + ] = 1.0 else: - feature_matrix[i][self.one_hot_mapper['symbol'][atom.GetSymbol()]] = 1.0 + feature_matrix[i][ + self.one_hot_mapper["symbol"][atom.GetSymbol()] + ] = 1.0 - if 'num_hydrogen' in self.feature_list: + if "num_hydrogen" in self.feature_list: value = atom.GetTotalNumHs() if value in self.sets["num_hydrogen"]: - feature_matrix[i][self.one_hot_mapper['num_hydrogen'][atom.GetTotalNumHs()]] = 1.0 + feature_matrix[i][ + self.one_hot_mapper["num_hydrogen"][atom.GetTotalNumHs()] + ] = 1.0 else: - feature_matrix[i][self.one_hot_mapper['num_hydrogen'][list(self.sets['num_hydrogen'])[-1]] + 1] = 1.0 - if 'ring_type' in self.feature_list: + feature_matrix[i][ + self.one_hot_mapper["num_hydrogen"][ + list(self.sets["num_hydrogen"])[-1] + ] + + 1 + ] = 1.0 + if "ring_type" in self.feature_list: if not atom.IsInRing(): ring_type = "no-ring" elif atom.IsInRingSize(7): @@ -66,38 +86,49 @@ def encode(self, G, encoder_type: Literal['one_hot', 'number']): ring_type = "5-cycle" else: ring_type = "small-ring" - feature_matrix[i][self.one_hot_mapper['ring_type'][ring_type]] = 1.0 - if 'hybridization' in self.feature_list: + feature_matrix[i][self.one_hot_mapper["ring_type"][ring_type]] = 1.0 + if "hybridization" in self.feature_list: 
orbi = atom.GetHybridization().name - if orbi in self.sets['hybridization']: - feature_matrix[i][self.one_hot_mapper['hybridization'][orbi]] = 1.0 + if orbi in self.sets["hybridization"]: + feature_matrix[i][ + self.one_hot_mapper["hybridization"][orbi] + ] = 1.0 else: - feature_matrix[i][self.one_hot_mapper['hybridization'][list(self.sets['hybridization'])[-1]] + 1] = 1.0 - - else: # Case: Number mapping - feature_matrix = torch.zeros(G.number_of_nodes(), len(self.feature_list), dtype=torch.int) + feature_matrix[i][ + self.one_hot_mapper["hybridization"][ + list(self.sets["hybridization"])[-1] + ] + + 1 + ] = 1.0 + + else: # Case: Number mapping + feature_matrix = torch.zeros( + G.number_of_nodes(), len(self.feature_list), dtype=torch.int + ) for i in range(G.number_of_nodes()): - atom = G.nodes()[i]['atom'] + atom = G.nodes()[i]["atom"] for j, feature in enumerate(self.feature_list): if feature == "symbol": - if atom.GetSymbol() in self.sets['symbol']: - feature_matrix[i][j] = self.number_mapper[feature][atom.GetSymbol()] + if atom.GetSymbol() in self.sets["symbol"]: + feature_matrix[i][j] = self.number_mapper[feature][ + atom.GetSymbol() + ] else: feature_matrix[i][j] = self.feature_numbers[feature] - 1 - elif feature == 'num_hydrogen': + elif feature == "num_hydrogen": value = atom.GetTotalNumHs() if value in self.sets["num_hydrogen"]: feature_matrix[i][j] = self.number_mapper[feature][value] else: feature_matrix[i][j] = self.feature_numbers[feature] - 1 - elif feature == 'valence_electrons': + elif feature == "valence_electrons": value = atom.GetExplicitValence() if value in self.sets["valence_electrons"]: feature_matrix[i][j] = self.number_mapper[feature][value] else: feature_matrix[i][j] = self.feature_numbers[feature] - 1 - elif feature == 'oxidation_number': + elif feature == "oxidation_number": raise NotImplementedError() value = Chem.rdMolDescriptors.CalcOxidationNumbers(atom) if value in self.sets["oxidation_number"]: @@ -105,7 +136,7 @@ def encode(self, G, encoder_type: Literal['one_hot', 'number']): else: feature_matrix[i][j] = self.feature_numbers[feature] - 1 - elif feature == 'ring_type': + elif feature == "ring_type": if not atom.IsInRing(): ring_type = "no-ring" elif atom.IsInRingSize(7): @@ -117,9 +148,9 @@ def encode(self, G, encoder_type: Literal['one_hot', 'number']): else: ring_type = "small-ring" feature_matrix[i][j] = self.number_mapper[feature][ring_type] - if feature == 'hybridization': + if feature == "hybridization": orbi = atom.GetHybridization().name - if orbi in self.sets['hybridization']: + if orbi in self.sets["hybridization"]: feature_matrix[i][j] = self.number_mapper[feature][orbi] else: feature_matrix[i][j] = self.feature_numbers[feature] - 1 diff --git a/fiora/GNN/BondFeatureEncoder.py b/fiora/GNN/BondFeatureEncoder.py index 2de3a07..05aaf60 100644 --- a/fiora/GNN/BondFeatureEncoder.py +++ b/fiora/GNN/BondFeatureEncoder.py @@ -1,6 +1,4 @@ - -import torch -import numpy as np +import torch from typing import Literal @@ -11,35 +9,43 @@ def __init__(self, feature_list=["bond_type", "ring_type"]): self.sets = { "bond_type": ["AROMATIC", "SINGLE", "DOUBLE", "TRIPLE"], "ring_type": ["no-ring", "small-ring", "5-cycle", "6-cycle", "large-ring"], - "ring_type_binary": ["is_in_ring"], + "ring_type_binary": ["is_in_ring"], } - self.reduced_features = [] # Reduced features may have additional values that will be encoded with another bit (representing OTHERS) + self.reduced_features = [] # Reduced features may have additional values that will be encoded with 
another bit (representing OTHERS) self.one_hot_mapper = {} self.number_mapper = {} self.feature_numbers = {} for feature in self.feature_list: variables = self.sets[feature] num_variables = len(variables) - self.one_hot_mapper[feature] = dict(zip(variables, range(self.encoding_dim, num_variables + self.encoding_dim))) + self.one_hot_mapper[feature] = dict( + zip( + variables, + range(self.encoding_dim, num_variables + self.encoding_dim), + ) + ) self.number_mapper[feature] = dict(zip(variables, range(0, num_variables))) self.encoding_dim += num_variables - + if feature in self.reduced_features: self.encoding_dim += 1 num_variables += 1 self.feature_numbers[feature] = num_variables - - def encode(self, G, edges,encoder_type: Literal['one_hot', 'number']): - - if encoder_type == 'one_hot': - feature_matrix = torch.zeros(len(edges), self.encoding_dim, dtype=torch.float32) + def encode(self, G, edges, encoder_type: Literal["one_hot", "number"]): + + if encoder_type == "one_hot": + feature_matrix = torch.zeros( + len(edges), self.encoding_dim, dtype=torch.float32 + ) - for i, (u,v) in enumerate(edges): + for i, (u, v) in enumerate(edges): bond = G[u][v]["bond"] if "bond_type" in self.feature_list: - feature_matrix[i][self.one_hot_mapper["bond_type"][bond.GetBondType().name]] = 1.0 - if 'ring_type' in self.feature_list: + feature_matrix[i][ + self.one_hot_mapper["bond_type"][bond.GetBondType().name] + ] = 1.0 + if "ring_type" in self.feature_list: if not bond.IsInRing(): ring_type = "no-ring" elif bond.IsInRingSize(7): @@ -50,26 +56,31 @@ def encode(self, G, edges,encoder_type: Literal['one_hot', 'number']): ring_type = "5-cycle" else: ring_type = "small-ring" - feature_matrix[i][self.one_hot_mapper['ring_type'][ring_type]] = 1.0 - if 'ring_type_binary' in self.feature_list: + feature_matrix[i][self.one_hot_mapper["ring_type"][ring_type]] = 1.0 + if "ring_type_binary" in self.feature_list: if bond.IsInRing(): - feature_matrix[i][self.one_hot_mapper['ring_type_binary']["is_in_ring"]] = 1.0 + feature_matrix[i][ + self.one_hot_mapper["ring_type_binary"]["is_in_ring"] + ] = 1.0 # else case implicit = 0 - - elif encoder_type == 'number': # Case: Number mapping - feature_matrix = torch.zeros(len(edges), len(self.feature_list), dtype=torch.int) - for i, (u,v) in enumerate(edges): - bond = G[u][v]["bond"] + elif encoder_type == "number": # Case: Number mapping + feature_matrix = torch.zeros( + len(edges), len(self.feature_list), dtype=torch.int + ) + for i, (u, v) in enumerate(edges): + bond = G[u][v]["bond"] for j, feature in enumerate(self.feature_list): if feature == "bond_type": value = bond.GetBondType().name - if value in self.sets['bond_type']: + if value in self.sets["bond_type"]: feature_matrix[i][j] = self.number_mapper[feature][value] else: - raise NotImplementedError("Unknown bond type is not accounted for.") - elif feature == 'ring_type': + raise NotImplementedError( + "Unknown bond type is not accounted for." + ) + elif feature == "ring_type": if not bond.IsInRing(): ring_type = "no-ring" elif bond.IsInRingSize(7): @@ -81,7 +92,9 @@ def encode(self, G, edges,encoder_type: Literal['one_hot', 'number']): else: ring_type = "small-ring" feature_matrix[i][j] = self.number_mapper[feature][ring_type] - if feature == 'ring_type_binary': - raise NotImplementedError("Binary feature not implemented with number embedding. Use default 'ring_type' instead.") + if feature == "ring_type_binary": + raise NotImplementedError( + "Binary feature not implemented with number embedding. 
Use default 'ring_type' instead." + ) - return feature_matrix \ No newline at end of file + return feature_matrix diff --git a/fiora/GNN/CovariateFeatureEncoder.py b/fiora/GNN/CovariateFeatureEncoder.py index 456faf0..b2890b7 100644 --- a/fiora/GNN/CovariateFeatureEncoder.py +++ b/fiora/GNN/CovariateFeatureEncoder.py @@ -1,50 +1,73 @@ - -import torch -import numpy as np +import torch from fiora.MOL.constants import ORDERED_ELEMENT_LIST_WITH_HYDROGEN + class CovariateFeatureEncoder: - def __init__(self, feature_list=["collision_energy", "molecular_weight", "precursor_mode", "instrument", "element_composition"], sets_overwrite: dict|None=None): + def __init__( + self, + feature_list=[ + "collision_energy", + "molecular_weight", + "precursor_mode", + "instrument", + "element_composition", + ], + sets_overwrite: dict | None = None, + ): if "ce_steps" in feature_list: - raise ValueError("'ce_steps' is not meant as a setup feature. Remove from feature_list") + raise ValueError( + "'ce_steps' is not meant as a setup feature. Remove from feature_list" + ) self.encoding_dim = 0 self.feature_list = feature_list self.categorical_sets = { - "instrument": ["HCD", "Q-TOF", "IT-FT/ion trap with FTMS", "IT/ion trap"], # "IT-FT/ion trap with FTMS", "IT/ion trap", "QqQ", "QqQ/triple quadrupole" - "precursor_mode": ["[M+H]+", "[M-H]-"] + "instrument": [ + "HCD", + "Q-TOF", + "IT-FT/ion trap with FTMS", + "IT/ion trap", + ], # "IT-FT/ion trap with FTMS", "IT/ion trap", "QqQ", "QqQ/triple quadrupole" + "precursor_mode": ["[M+H]+", "[M-H]-"], } if sets_overwrite: for new_set, new_categories in sets_overwrite.items(): self.categorical_sets[new_set] = new_categories - - self.continuous_set = { - "collision_energy", - "molecular_weight" - } + + self.continuous_set = {"collision_energy", "molecular_weight"} self.normalize_features = { "collision_energy": {"min": 0, "max": 100, "transform": "linear"}, - "molecular_weight": {"min": 0, "max": 1000, "transform": "linear"} + "molecular_weight": {"min": 0, "max": 1000, "transform": "linear"}, } - - self.reduced_categorical_features = ["instrument"] # Reduced features may have additional values that will be encoded with another bit (representing OTHERS) + + self.reduced_categorical_features = [ + "instrument" + ] # Reduced features may have additional values that will be encoded with another bit (representing OTHERS) self.one_hot_mapper = {} for feature in self.feature_list: if feature in self.categorical_sets.keys(): variables = self.categorical_sets[feature] - self.one_hot_mapper[feature] = dict(zip(variables, range(self.encoding_dim, len(variables) + self.encoding_dim))) + self.one_hot_mapper[feature] = dict( + zip( + variables, + range(self.encoding_dim, len(variables) + self.encoding_dim), + ) + ) self.encoding_dim += len(variables) if feature in self.reduced_categorical_features: self.encoding_dim += 1 if feature in self.continuous_set: self.one_hot_mapper[feature] = self.encoding_dim self.encoding_dim += 1 - + if "element_composition" in self.feature_list: self.one_hot_mapper["element_composition"] = { - element: idx for idx, element in enumerate(ORDERED_ELEMENT_LIST_WITH_HYDROGEN, start=self.encoding_dim) + element: idx + for idx, element in enumerate( + ORDERED_ELEMENT_LIST_WITH_HYDROGEN, start=self.encoding_dim + ) } # Note that element composition is using int numbers and not one hot mapping. But the index is still correct. 
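        # Illustration (assuming ORDERED_ELEMENT_LIST_WITH_HYDROGEN starts with
        # "H", "C", "N", "O"): ethanol (C2H6O) writes the raw counts 6, 2, 0 and 1
        # into the corresponding columns instead of one-hot indicator bits.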
self.encoding_dim += len(ORDERED_ELEMENT_LIST_WITH_HYDROGEN) - + def encode(self, dim0, metadata, G=None): feature_matrix = torch.zeros(dim0, self.encoding_dim, dtype=torch.float32) for feature in self.feature_list: @@ -53,44 +76,69 @@ def encode(self, dim0, metadata, G=None): if value in self.categorical_sets[feature]: feature_matrix[:, self.one_hot_mapper[feature][value]] = 1.0 else: - feature_matrix[:, self.one_hot_mapper[feature][list(self.categorical_sets[feature])[-1]] + 1] = 1.0 - + feature_matrix[ + :, + self.one_hot_mapper[feature][ + list(self.categorical_sets[feature])[-1] + ] + + 1, + ] = 1.0 + elif feature in self.continuous_set: value = metadata[feature] if feature in self.normalize_features.keys(): - value = (value - self.normalize_features[feature]["min"]) / (self.normalize_features[feature]["max"] - self.normalize_features[feature]["min"]) + value = (value - self.normalize_features[feature]["min"]) / ( + self.normalize_features[feature]["max"] + - self.normalize_features[feature]["min"] + ) feature_matrix[:, self.one_hot_mapper[feature]] = value feature_matrix = torch.clamp(feature_matrix, 0.0, 1.0) elif feature == "element_composition": if G is None: - raise ValueError("Graph G must be provided to encode 'element_composition'") + raise ValueError( + "Graph G must be provided to encode 'element_composition'" + ) element_composition = self.get_element_composition(G) for idx, element in enumerate(ORDERED_ELEMENT_LIST_WITH_HYDROGEN): - feature_matrix[:, self.one_hot_mapper["element_composition"][element]] = element_composition[idx] + feature_matrix[ + :, self.one_hot_mapper["element_composition"][element] + ] = element_composition[idx] return feature_matrix - + def normalize_collision_steps(self, ce_steps): - norm_ce = lambda x: (x - self.normalize_features["collision_energy"]["min"]) / (self.normalize_features["collision_energy"]["max"] - self.normalize_features["collision_energy"]["min"]) + norm_ce = lambda x: ( + (x - self.normalize_features["collision_energy"]["min"]) + / ( + self.normalize_features["collision_energy"]["max"] + - self.normalize_features["collision_energy"]["min"] + ) + ) ce_steps = [norm_ce(x) for x in ce_steps] return ce_steps - + @staticmethod def get_element_composition(G): # Initialize composition vector with zeros - element_composition = torch.zeros(len(ORDERED_ELEMENT_LIST_WITH_HYDROGEN), dtype=torch.float32) + element_composition = torch.zeros( + len(ORDERED_ELEMENT_LIST_WITH_HYDROGEN), dtype=torch.float32 + ) # Iterate through nodes in the graph for node in G.nodes: - atom = G.nodes[node]['atom'] + atom = G.nodes[node]["atom"] symbol = atom.GetSymbol() # Get the atomic symbol if symbol in ORDERED_ELEMENT_LIST_WITH_HYDROGEN: - index = ORDERED_ELEMENT_LIST_WITH_HYDROGEN.index(symbol) # Find the index of the element + index = ORDERED_ELEMENT_LIST_WITH_HYDROGEN.index( + symbol + ) # Find the index of the element element_composition[index] += 1 # Increment the count for the element # Add hydrogens explicitly hydrogens = atom.GetTotalNumHs() - hydrogen_index = ORDERED_ELEMENT_LIST_WITH_HYDROGEN.index('H') # Ensure 'H' is in ORDERED_ELEMENT_LIST + hydrogen_index = ORDERED_ELEMENT_LIST_WITH_HYDROGEN.index( + "H" + ) # Ensure 'H' is in ORDERED_ELEMENT_LIST element_composition[hydrogen_index] += hydrogens - return element_composition \ No newline at end of file + return element_composition diff --git a/fiora/GNN/Datasets.py b/fiora/GNN/Datasets.py index 33aeb2d..0fe23b3 100644 --- a/fiora/GNN/Datasets.py +++ b/fiora/GNN/Datasets.py @@ -3,7 +3,7 @@ import 
numpy as np from torch.utils.data import Dataset -''' +""" class AtomAromaticityData(Dataset): def __init__(self, df) -> None: self.X = np.concatenate(df["features"].values, dtype='float32') @@ -17,18 +17,25 @@ def __getitem__(self, idx): def num_features(self): return self.X.shape[1] -''' - +""" class AtomAromaticityData(Dataset): def __init__(self, df) -> None: - self.X = df["features"].apply(lambda x: torch.tensor(x, dtype=torch.float32)).values - self.A = df["Atilde"].apply(lambda x: torch.tensor(x, dtype=torch.float32)).values - self.y = df["is_aromatic"].apply(lambda x: torch.tensor(x, dtype=torch.float32)).values*1 + self.X = ( + df["features"].apply(lambda x: torch.tensor(x, dtype=torch.float32)).values + ) + self.A = ( + df["Atilde"].apply(lambda x: torch.tensor(x, dtype=torch.float32)).values + ) + self.y = ( + df["is_aromatic"] + .apply(lambda x: torch.tensor(x, dtype=torch.float32)) + .values + * 1 + ) - def __len__(self): return len(self.X) @@ -38,11 +45,16 @@ def __getitem__(self, idx): def num_features(self): return self.X[0].shape[1] + class SimpleNodeData(Dataset): - def __init__(self, data: pd.Series, feature_tag: str, label: str, device="cpu") -> None: - self.X = torch.cat(data.apply(lambda x: getattr(x, feature_tag).to(device)).to_list()) + def __init__( + self, data: pd.Series, feature_tag: str, label: str, device="cpu" + ) -> None: + self.X = torch.cat( + data.apply(lambda x: getattr(x, feature_tag).to(device)).to_list() + ) self.y = torch.cat(data.apply(lambda x: getattr(x, label).to(device)).to_list()) - + def __len__(self): return len(self.X) @@ -51,17 +63,18 @@ def __getitem__(self, idx): def num_features(self): return self.X.shape[1] - - + + class NodeSingleLabelData(Dataset): - def __init__(self, data: pd.Series, feature_tag: str, adj_tag: str, label: str, device="cpu") -> None: + def __init__( + self, data: pd.Series, feature_tag: str, adj_tag: str, label: str, device="cpu" + ) -> None: self.X = data.apply(lambda x: getattr(x, feature_tag).to(device)).values self.A = data.apply(lambda x: getattr(x, adj_tag).to(device)).values self.num_nodes = np.array(list(map(lambda x: x.shape[0], self.X))) self.y = data.apply(lambda x: getattr(x, label).to(device)).values - - + def __len__(self): return len(self.X) @@ -70,39 +83,74 @@ def __getitem__(self, idx): def num_features(self): return self.X[0].shape[1] - + class EdgeSingleLabelData(Dataset): - def __init__(self, data: pd.Series, feature_tag: str, left_tag: str, right_tag: str, adj_tag: str, edge_feature_tag: str, static_feature_tag: str, label: str, validation_mask_tag: str, group_id: str, device="cpu") -> None: + def __init__( + self, + data: pd.Series, + feature_tag: str, + left_tag: str, + right_tag: str, + adj_tag: str, + edge_feature_tag: str, + static_feature_tag: str, + label: str, + validation_mask_tag: str, + group_id: str, + device="cpu", + ) -> None: self.X = data.apply(lambda x: getattr(x, feature_tag).to(device)).values self.A = data.apply(lambda x: getattr(x, adj_tag).to(device)).values - self.AL = data.apply(lambda x: getattr(x, left_tag).to(device)).values # matrix to list all nodes to the left of an edge - self.AR = data.apply(lambda x: getattr(x, right_tag).to(device)).values # matrix to list all nodes to the right of an edge + self.AL = data.apply( + lambda x: getattr(x, left_tag).to(device) + ).values # matrix to list all nodes to the left of an edge + self.AR = data.apply( + lambda x: getattr(x, right_tag).to(device) + ).values # matrix to list all nodes to the right of an edge self.num_nodes = 
np.array(list(map(lambda x: x.shape[0], self.X))) self.y = data.apply(lambda x: getattr(x, label).to(device)).values self.num_edges = np.array(list(map(lambda x: x.shape[0], self.y))) - self.edge_features = data.apply(lambda x: getattr(x, edge_feature_tag).to(device)).values - self.static_features = data.apply(lambda x: getattr(x, static_feature_tag).to(device)).values - self.validation_mask = data.apply(lambda x: getattr(x, validation_mask_tag).to(device)).values + self.edge_features = data.apply( + lambda x: getattr(x, edge_feature_tag).to(device) + ).values + self.static_features = data.apply( + lambda x: getattr(x, static_feature_tag).to(device) + ).values + self.validation_mask = data.apply( + lambda x: getattr(x, validation_mask_tag).to(device) + ).values self.group_id = data.apply(lambda x: getattr(x, group_id)).values - - + def __len__(self): return len(self.X) def __getitem__(self, idx): - return [self.X[idx], self.A[idx], self.num_nodes[idx], self.num_edges[idx], self.edge_features[idx], self.static_features[idx], self.y[idx], self.AL[idx], self.AR[idx], self.validation_mask[idx]] + return [ + self.X[idx], + self.A[idx], + self.num_nodes[idx], + self.num_edges[idx], + self.edge_features[idx], + self.static_features[idx], + self.y[idx], + self.AL[idx], + self.AR[idx], + self.validation_mask[idx], + ] - def num_features(self): # Number of node features + def num_features(self): # Number of node features return self.X[0].shape[1] - - def num_static_features(self): # Number of additional features concatenated for the edge classification + + def num_static_features( + self, + ): # Number of additional features concatenated for the edge classification return self.static_features[0].shape[1] - + def get_unique_groups(self): return np.unique(self.group_id) - + def get_indices_of_groups(self, groups): indices = np.array([], dtype=int) for g in groups: @@ -110,7 +158,8 @@ def get_indices_of_groups(self, groups): indices = np.concatenate((indices, ids), axis=0) return indices -''' + +""" def collate_graph_batch(batch): X, A, no_nodes, y = zip(*batch) max_nodes = np.max(no_nodes) @@ -120,77 +169,103 @@ def collate_graph_batch(batch): A = torch.stack(list(map(pad_matrix, A, pad_sizes))) y = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=0) return X, A, y -''' +""" + def collate_graph_batch(batch): X, A, num_nodes, y = zip(*batch) max_nodes = np.max(num_nodes) pad_sizes = max_nodes - num_nodes X = torch.nn.utils.rnn.pad_sequence(X, batch_first=True, padding_value=0) - pad_matrix = lambda x, y: torch.nn.functional.pad(x, (0,y,0,y), value=0) + pad_matrix = lambda x, y: torch.nn.functional.pad(x, (0, y, 0, y), value=0) A = torch.stack(list(map(pad_matrix, A, pad_sizes))) y = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=0) - node_mask = torch.zeros((X.shape[0], max_nodes), dtype=torch.bool).to(X.get_device()) + node_mask = torch.zeros((X.shape[0], max_nodes), dtype=torch.bool).to( + X.get_device() + ) adj_mask = torch.zeros((X.shape[0], max_nodes, max_nodes), dtype=torch.bool) for i in range(node_mask.shape[0]): - node_mask[i, :num_nodes[i]] = 1 - adj_mask[i, :num_nodes[i], :num_nodes[i]] = 1 - + node_mask[i, : num_nodes[i]] = 1 + adj_mask[i, : num_nodes[i], : num_nodes[i]] = 1 + batch_record = { - 'X': X, - 'A': A, - 'y': y, - 'node_mask': node_mask, - 'adj_mask': adj_mask, - 'num_of_nodes': torch.tensor(list(num_nodes)).unsqueeze(dim=1) + "X": X, + "A": A, + "y": y, + "node_mask": node_mask, + "adj_mask": adj_mask, + "num_of_nodes": 
torch.tensor(list(num_nodes)).unsqueeze(dim=1), } return batch_record def collate_graph_edge_batch(batch): - X, A, num_nodes, num_edges, edge_features, static_features, y, AL, AR, validation_bits = zip(*batch) + ( + X, + A, + num_nodes, + num_edges, + edge_features, + static_features, + y, + AL, + AR, + validation_bits, + ) = zip(*batch) max_nodes = np.max(num_nodes) pad_sizes = max_nodes - num_nodes max_edges = np.max(num_edges) - X = torch.nn.utils.rnn.pad_sequence(X, batch_first=True, padding_value=0) - edge_features = torch.nn.utils.rnn.pad_sequence(edge_features, batch_first=True, padding_value=0) - static_features = torch.nn.utils.rnn.pad_sequence(static_features, batch_first=True, padding_value=0) - pad_matrix = lambda x, y: torch.nn.functional.pad(x, (0,y,0,y), value=0) + edge_features = torch.nn.utils.rnn.pad_sequence( + edge_features, batch_first=True, padding_value=0 + ) + static_features = torch.nn.utils.rnn.pad_sequence( + static_features, batch_first=True, padding_value=0 + ) + + def pad_matrix(x, y): + return torch.nn.functional.pad(x, (0, y, 0, y), value=0) + A = torch.stack(list(map(pad_matrix, A, pad_sizes))) y = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=0) - - pad_matrix = lambda x, y, z: torch.nn.functional.pad(x, (0,y,0,z), value=0) # Pad helper matrices to maximum size + + def pad_matrix(x, y, z): + return torch.nn.functional.pad( + x, (0, y, 0, z), value=0 + ) # Pad helper matrices to maximum size + pad_edge_sizes = max_edges - num_edges AL = torch.stack(list(map(pad_matrix, AL, pad_sizes, pad_edge_sizes))) AR = torch.stack(list(map(pad_matrix, AR, pad_sizes, pad_edge_sizes))) - - node_mask = torch.zeros((X.shape[0], max_nodes), dtype=torch.bool).to(X.get_device()) + + node_mask = torch.zeros((X.shape[0], max_nodes), dtype=torch.bool).to( + X.get_device() + ) adj_mask = torch.zeros((X.shape[0], max_nodes, max_nodes), dtype=torch.bool) - + y_mask = torch.zeros((y.shape[0], max_edges), dtype=torch.bool).to(X.get_device()) validation_mask = torch.zeros((y.shape[0], max_edges), dtype=torch.bool) for i in range(node_mask.shape[0]): - node_mask[i, :num_nodes[i]] = 1 - adj_mask[i, :num_nodes[i], :num_nodes[i]] = 1 - y_mask[i, :num_edges[i]] = 1 - validation_mask[i, :num_edges[i]] = validation_bits[i].flatten() - + node_mask[i, : num_nodes[i]] = 1 + adj_mask[i, : num_nodes[i], : num_nodes[i]] = 1 + y_mask[i, : num_edges[i]] = 1 + validation_mask[i, : num_edges[i]] = validation_bits[i].flatten() + batch_record = { - 'X': X, - 'A': A, - 'y': y, - 'AL': AL, - 'AR': AR, - 'node_mask': node_mask, - 'adj_mask': adj_mask, - 'y_mask': y_mask, - 'edge_features': edge_features, - 'static_features': static_features, - 'validation_mask': validation_mask, - 'num_of_nodes': torch.tensor(list(num_nodes)).unsqueeze(dim=1), - 'num_of_edges': torch.tensor(list(num_edges)).unsqueeze(dim=1) + "X": X, + "A": A, + "y": y, + "AL": AL, + "AR": AR, + "node_mask": node_mask, + "adj_mask": adj_mask, + "y_mask": y_mask, + "edge_features": edge_features, + "static_features": static_features, + "validation_mask": validation_mask, + "num_of_nodes": torch.tensor(list(num_nodes)).unsqueeze(dim=1), + "num_of_edges": torch.tensor(list(num_edges)).unsqueeze(dim=1), } - return batch_record \ No newline at end of file + return batch_record diff --git a/fiora/GNN/EdgePropertyPredictor.py b/fiora/GNN/EdgePropertyPredictor.py index 878ff61..d2dc846 100644 --- a/fiora/GNN/EdgePropertyPredictor.py +++ b/fiora/GNN/EdgePropertyPredictor.py @@ -5,43 +5,74 @@ from fiora.MOL.constants import 
ORDERED_ELEMENT_LIST_WITH_HYDROGEN + class EdgePropertyPredictor(torch.nn.Module): - def __init__(self, edge_feature_dict: Dict, hidden_features: int, static_features: int, out_dimension: int, dense_depth: int=0, dense_dim: int=None, embedding_dim: int=200, embedding_aggregation_type: str='concat', residual_connections: bool=False, subgraph_features: bool=False, pooling_func: Literal["avg", "max"]="avg", input_dropout: float=0, latent_dropout: float=0) -> None: - ''' Initialize the EdgePropertyPredictor model. - Args: - edge_feature_dict (dict): Dictionary containing edge feature information. - hidden_features (int): Number of hidden features for each layer. - static_features (int): Number of static features to be concatenated. - out_dimension (int): Output dimension of the model. - dense_depth (int, optional): Number of dense layers. Defaults to 0. - dense_dim (int, optional): Dimension of the dense layers. If None, it will be set to the number of input features. Defaults to None. - embedding_dim (int, optional): Dimension of the edge embeddings. Defaults to 200. - embedding_aggregation_type (str, optional): Type of aggregation for edge embeddings. Defaults to 'concat'. - residual_connections (bool, optional): Whether to use residual connections. Defaults to False. - input_dropout (float, optional): Dropout rate for input features. Defaults to 0. - latent_dropout (float, optional): Dropout rate for latent features. Defaults to 0. - ''' + def __init__( + self, + edge_feature_dict: Dict, + hidden_features: int, + static_features: int, + out_dimension: int, + dense_depth: int = 0, + dense_dim: int = None, + embedding_dim: int = 200, + embedding_aggregation_type: str = "concat", + residual_connections: bool = False, + subgraph_features: bool = False, + pooling_func: Literal["avg", "max"] = "avg", + input_dropout: float = 0, + latent_dropout: float = 0, + ) -> None: + """Initialize the EdgePropertyPredictor model. + Args: + edge_feature_dict (dict): Dictionary containing edge feature information. + hidden_features (int): Number of hidden features for each layer. + static_features (int): Number of static features to be concatenated. + out_dimension (int): Output dimension of the model. + dense_depth (int, optional): Number of dense layers. Defaults to 0. + dense_dim (int, optional): Dimension of the dense layers. If None, it will be set to the number of input features. Defaults to None. + embedding_dim (int, optional): Dimension of the edge embeddings. Defaults to 200. + embedding_aggregation_type (str, optional): Type of aggregation for edge embeddings. Defaults to 'concat'. + residual_connections (bool, optional): Whether to use residual connections. Defaults to False. + input_dropout (float, optional): Dropout rate for input features. Defaults to 0. + latent_dropout (float, optional): Dropout rate for latent features. Defaults to 0. 
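            subgraph_features (bool, optional): Whether to append pooled left/right fragment embeddings and per-fragment element compositions to each edge representation. Defaults to False.
            pooling_func (Literal["avg", "max"], optional): Pooling function used for the subgraph features. Defaults to "avg".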
+ """ super().__init__() self.activation = torch.nn.ELU() - #self.edge_embedding = FeatureEmbedding(edge_feature_dict, embedding_dim, aggregation_type=embedding_aggregation_type) + # self.edge_embedding = FeatureEmbedding(edge_feature_dict, embedding_dim, aggregation_type=embedding_aggregation_type) self.input_dropout = torch.nn.Dropout(input_dropout) self.latent_dropout = torch.nn.Dropout(latent_dropout) self.residual_connections = residual_connections self.subgraph_features = subgraph_features - self.pooling_func = geom_nn.global_mean_pool if pooling_func == "avg" else geom_nn.global_max_pool - num_subgraph_features = 2*len(ORDERED_ELEMENT_LIST_WITH_HYDROGEN) + hidden_features*2 if subgraph_features else 0 + self.pooling_func = ( + geom_nn.global_mean_pool + if pooling_func == "avg" + else geom_nn.global_max_pool + ) + num_subgraph_features = ( + 2 * len(ORDERED_ELEMENT_LIST_WITH_HYDROGEN) + hidden_features * 2 + if subgraph_features + else 0 + ) dense_layers = [] - num_features = hidden_features*2 + num_subgraph_features + embedding_dim + static_features + num_features = ( + hidden_features * 2 + + num_subgraph_features + + embedding_dim + + static_features + ) hidden_dimension = dense_dim if dense_dim is not None else num_features if hidden_dimension != num_features and residual_connections: - raise NotImplementedError("Residual connections require the hidden dimension to match the input dimension.") + raise NotImplementedError( + "Residual connections require the hidden dimension to match the input dimension." + ) for _ in range(dense_depth): dense_layers += [torch.nn.Linear(num_features, hidden_dimension)] num_features = hidden_dimension self.dense_layers = torch.nn.ModuleList(dense_layers) - + self.output_layer = torch.nn.Linear(num_features, out_dimension) def concat_node_pairs(self, X, batch): @@ -49,25 +80,33 @@ def concat_node_pairs(self, X, batch): Concatenates node pairs and optionally adds subgraph features, including element composition. This version includes debug messages to trace tensor shapes. """ - + src, dst = batch["edge_index"] - + # 1. Node Pair Concatenation X_src = X[src] X_dst = X[dst] node_pairs = torch.cat([X_src, X_dst], dim=1) - if self.subgraph_features: + if self.subgraph_features: # 2. 
Subgraph Feature Pooling num_edges = batch.edge_index.size(1) edge_batch_map = batch.batch[batch.edge_index[0]] - left_indices = batch.subgraph_idx_left + batch.ptr[edge_batch_map].unsqueeze(1) - right_indices = batch.subgraph_idx_right + batch.ptr[edge_batch_map].unsqueeze(1) + left_indices = batch.subgraph_idx_left + batch.ptr[ + edge_batch_map + ].unsqueeze(1) + right_indices = batch.subgraph_idx_right + batch.ptr[ + edge_batch_map + ].unsqueeze(1) - left_batch_vec = torch.arange(num_edges, device=X.device).repeat_interleave(left_indices.size(1)) - right_batch_vec = torch.arange(num_edges, device=X.device).repeat_interleave(right_indices.size(1)) + left_batch_vec = torch.arange(num_edges, device=X.device).repeat_interleave( + left_indices.size(1) + ) + right_batch_vec = torch.arange( + num_edges, device=X.device + ).repeat_interleave(right_indices.size(1)) left_flat = left_indices.flatten() right_flat = right_indices.flatten() @@ -78,28 +117,40 @@ def concat_node_pairs(self, X, batch): pooled_right = torch.zeros(num_edges, X.size(1), device=X.device) if left_mask.any(): - pooled_left = self.pooling_func(X[left_flat[left_mask]], left_batch_vec[left_mask], size=num_edges) + pooled_left = self.pooling_func( + X[left_flat[left_mask]], left_batch_vec[left_mask], size=num_edges + ) if right_mask.any(): - pooled_right = self.pooling_func(X[right_flat[right_mask]], right_batch_vec[right_mask], size=num_edges) - + pooled_right = self.pooling_func( + X[right_flat[right_mask]], + right_batch_vec[right_mask], + size=num_edges, + ) + subgraph_features = torch.cat([pooled_left, pooled_right], dim=1) # 3. Final Concatenation with Subgraph Features edge_elem_comp = batch["edge_elem_comp"] - + # This is the likely point of failure - node_pairs = torch.cat([node_pairs, subgraph_features, edge_elem_comp], dim=1) + node_pairs = torch.cat( + [node_pairs, subgraph_features, edge_elem_comp], dim=1 + ) return node_pairs - - def forward(self, X, batch): + + def forward(self, X, batch): # Melt node features into a stack of edges (represented by left and right node) X = self.concat_node_pairs(X, batch) - - # Add edge features and static features - edge_features = batch["edge_embedding"] #self.edge_embedding(batch["edge_attr"]) + + # Add edge features and static features + edge_features = batch[ + "edge_embedding" + ] # self.edge_embedding(batch["edge_attr"]) edge_features = self.input_dropout(edge_features) - X = torch.cat([X, edge_features, batch["static_edge_features"]], axis=-1) #self.input_dropout(batch["static_edge_features"]) - + X = torch.cat( + [X, edge_features, batch["static_edge_features"]], axis=-1 + ) # self.input_dropout(batch["static_edge_features"]) + # Apply fully connected layers for layer in self.dense_layers: X_skip = X @@ -109,4 +160,4 @@ def forward(self, X, batch): X = X + X_skip logits = self.output_layer(X) - return logits \ No newline at end of file + return logits diff --git a/fiora/GNN/FeatureEmbedding.py b/fiora/GNN/FeatureEmbedding.py index 74bf414..7d804ae 100644 --- a/fiora/GNN/FeatureEmbedding.py +++ b/fiora/GNN/FeatureEmbedding.py @@ -4,84 +4,108 @@ class FeatureEmbedding(torch.nn.Module): - def __init__(self, feature_dict: Dict[str, int], dim=200, aggregation_type=Literal['concat', 'sum']) -> None: + def __init__( + self, + feature_dict: Dict[str, int], + dim=200, + aggregation_type=Literal["concat", "sum"], + ) -> None: super().__init__() - + self.aggregation_type = aggregation_type self.feature_dim = dim - if aggregation_type == 'concat': + if aggregation_type == "concat": 
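            # Split the requested dimension evenly across the features; if it does not
            # divide cleanly, the total dimension is rounded down and a warning is
            # raised below.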
num_features = len(feature_dict.keys()) self.feature_dim = int(dim / num_features) self.dim = self.feature_dim * num_features if self.dim != dim: - warnings.warn(f"Desired embedding dimension not cleanly dividable by the number of features. Reducing dimension from {dim} to {self.dim}.") - elif aggregation_type == 'sum': + warnings.warn( + f"Desired embedding dimension not cleanly dividable by the number of features. Reducing dimension from {dim} to {self.dim}." + ) + elif aggregation_type == "sum": self.dim = dim self.feature_dim = dim else: - raise NameError(f"Unknown aggregation type selected. Valid types are {aggregation_type}.") - self.embeddings = torch.nn.ModuleList([torch.nn.Embedding(num, self.feature_dim) for feat, num in feature_dict.items()]) # Use ModuleDict instead? - + raise NameError( + f"Unknown aggregation type selected. Valid types are {aggregation_type}." + ) + self.embeddings = torch.nn.ModuleList( + [ + torch.nn.Embedding(num, self.feature_dim) + for feat, num in feature_dict.items() + ] + ) # Use ModuleDict instead? + def get_embedding_dimension(self): return self.dim - - + def forward(self, features, feature_mask=None): node_embeddings = [] for i, embedding in enumerate(self.embeddings): values = features[:, i] node_embeddings.append(embedding(values)) - - if self.aggregation_type == 'sum': + + if self.aggregation_type == "sum": embedded_features = torch.sum(torch.stack(node_embeddings, dim=-1), dim=-1) - elif self.aggregation_type == 'concat': + elif self.aggregation_type == "concat": embedded_features = torch.cat(node_embeddings, dim=-1) - + if feature_mask is not None: embedded_features = embedded_features * feature_mask.unsqueeze(-1) - - return embedded_features - + return embedded_features class FeatureEmbeddingPacked(torch.nn.Module): - def __init__(self, feature_dict: Dict[str, int], dim=200, aggregation_type=Literal['concat', 'sum']) -> None: + def __init__( + self, + feature_dict: Dict[str, int], + dim=200, + aggregation_type=Literal["concat", "sum"], + ) -> None: super().__init__() - + self.aggregation_type = aggregation_type self.feature_dim = dim - if aggregation_type == 'concat': + if aggregation_type == "concat": num_features = len(feature_dict.keys()) self.feature_dim = int(dim / num_features) self.dim = self.feature_dim * num_features if self.dim != dim: - warnings.warn(f"Desired embedding dimension not cleanly dividable by the number of features. Reducing dimension from {dim} to {self.dim}.") - elif aggregation_type == 'sum': + warnings.warn( + f"Desired embedding dimension not cleanly dividable by the number of features. Reducing dimension from {dim} to {self.dim}." + ) + elif aggregation_type == "sum": self.dim = dim self.feature_dim = dim else: - raise NameError(f"Unknown aggregation type selected. Valid types are {aggregation_type}.") - self.embeddings = torch.nn.ModuleList([torch.nn.Embedding(num, self.feature_dim) for feat, num in feature_dict.items()]) # Use ModuleDict instead? - + raise NameError( + f"Unknown aggregation type selected. Valid types are {aggregation_type}." + ) + self.embeddings = torch.nn.ModuleList( + [ + torch.nn.Embedding(num, self.feature_dim) + for feat, num in feature_dict.items() + ] + ) # Use ModuleDict instead? 
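        # Usage sketch (feature counts here are illustrative, not taken from this
        # file): FeatureEmbeddingPacked mirrors FeatureEmbedding but indexes padded
        # batches as [batch, node, feature]:
        #   emb = FeatureEmbeddingPacked(
        #       {"symbol": 16, "num_hydrogen": 5, "ring_type": 5},
        #       dim=200, aggregation_type="concat",
        #   )
        #   emb.get_embedding_dimension()  # 198 = 3 * int(200 / 3), warns about the drop
        #   x = torch.zeros(4, 12, 3, dtype=torch.long)  # 4 graphs, 12 padded nodes
        #   emb(x).shape  # torch.Size([4, 12, 198])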
+ def get_embedding_dimension(self): return self.dim - - + def forward(self, features, feature_mask=None): node_embeddings = [] for i, embedding in enumerate(self.embeddings): values = features[:, :, i] node_embeddings.append(embedding(values)) - - if self.aggregation_type == 'sum': + + if self.aggregation_type == "sum": embedded_features = torch.sum(torch.stack(node_embeddings, dim=-1), dim=-1) - elif self.aggregation_type == 'concat': + elif self.aggregation_type == "concat": embedded_features = torch.cat(node_embeddings, dim=-1) - + if feature_mask is not None: embedded_features = embedded_features * feature_mask.unsqueeze(-1) - - return embedded_features \ No newline at end of file + + return embedded_features diff --git a/fiora/GNN/FioraModel.py b/fiora/GNN/FioraModel.py index aba1db2..55e85cd 100644 --- a/fiora/GNN/FioraModel.py +++ b/fiora/GNN/FioraModel.py @@ -1,8 +1,7 @@ import torch -import torch_geometric.nn as geom_nn # Fiora GNN Modules -from fiora.GNN.FeatureEmbedding import FeatureEmbedding, FeatureEmbeddingPacked +from fiora.GNN.FeatureEmbedding import FeatureEmbedding from fiora.GNN.GNN import GNN from fiora.GNN.GraphPropertyPredictor import GraphPropertyPredictor from fiora.GNN.EdgePropertyPredictor import EdgePropertyPredictor @@ -15,44 +14,112 @@ class FioraModel(torch.nn.Module): def __init__(self, model_params: Dict) -> None: - ''' Initialize the FioraModel with the given parameters. - Args: - model_params (Dict): Dictionary containing model parameters such as node/edge feature layouts, embedding dimensions, hidden dimensions, etc. - ''' + """Initialize the FioraModel with the given parameters. + Args: + model_params (Dict): Dictionary containing model parameters such as node/edge feature layouts, embedding dimensions, hidden dimensions, etc. 
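        Expected keys include "node_feature_layout", "edge_feature_layout",
        "embedding_dimension", "embedding_aggregation", "hidden_dimension", "depth",
        "gnn_type", "layer_norm", "dense_layers", "dense_dim", "output_dimension",
        "static_feature_dimension", "pooling_func", "input_dropout" and
        "latent_dropout"; optional keys missing from older configurations are
        back-filled by _version_control_model_params.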
+ """ super().__init__() - + self._version_control_model_params(model_params) self.edge_dim = model_params["output_dimension"] - self.node_embedding = FeatureEmbedding(feature_dict=model_params["node_feature_layout"], dim=model_params["embedding_dimension"], aggregation_type=model_params["embedding_aggregation"]) - self.edge_embedding = FeatureEmbedding(feature_dict=model_params["edge_feature_layout"], dim=model_params["embedding_dimension"], aggregation_type=model_params["embedding_aggregation"]) - self.GNN_module = GNN(hidden_features=model_params["hidden_dimension"], depth=model_params["depth"], embedding_dim=self.node_embedding.get_embedding_dimension(), embedding_aggregation_type=model_params["embedding_aggregation"], gnn_type=model_params["gnn_type"], layer_norm=model_params["layer_norm"], residual_connections=model_params["residual_connections"], layer_stacking=model_params["layer_stacking"], input_dropout=model_params["input_dropout"], latent_dropout=model_params["latent_dropout"]) - self.edge_module = EdgePropertyPredictor(edge_feature_dict=model_params["edge_feature_layout"], hidden_features=self.GNN_module.get_embedding_dimension(), static_features=model_params["static_feature_dimension"], out_dimension=model_params["output_dimension"], dense_depth=model_params["dense_layers"], dense_dim=model_params["dense_dim"], embedding_dim=self.edge_embedding.get_embedding_dimension(), embedding_aggregation_type=model_params["embedding_aggregation"], residual_connections=model_params["residual_connections"], subgraph_features=model_params["subgraph_features"], pooling_func=model_params["pooling_func"], input_dropout=model_params["input_dropout"], latent_dropout=model_params["latent_dropout"]) - self.precursor_module = GraphPropertyPredictor(hidden_features=self.GNN_module.get_embedding_dimension(), static_features=model_params["static_feature_dimension"], out_dimension=1, dense_depth=model_params["dense_layers"], dense_dim=model_params["dense_dim"], residual_connections=model_params["residual_connections"], pooling_func=model_params["pooling_func"], input_dropout=model_params["input_dropout"], latent_dropout=model_params["latent_dropout"]) - + self.node_embedding = FeatureEmbedding( + feature_dict=model_params["node_feature_layout"], + dim=model_params["embedding_dimension"], + aggregation_type=model_params["embedding_aggregation"], + ) + self.edge_embedding = FeatureEmbedding( + feature_dict=model_params["edge_feature_layout"], + dim=model_params["embedding_dimension"], + aggregation_type=model_params["embedding_aggregation"], + ) + self.GNN_module = GNN( + hidden_features=model_params["hidden_dimension"], + depth=model_params["depth"], + embedding_dim=self.node_embedding.get_embedding_dimension(), + embedding_aggregation_type=model_params["embedding_aggregation"], + gnn_type=model_params["gnn_type"], + layer_norm=model_params["layer_norm"], + residual_connections=model_params["residual_connections"], + layer_stacking=model_params["layer_stacking"], + input_dropout=model_params["input_dropout"], + latent_dropout=model_params["latent_dropout"], + ) + self.edge_module = EdgePropertyPredictor( + edge_feature_dict=model_params["edge_feature_layout"], + hidden_features=self.GNN_module.get_embedding_dimension(), + static_features=model_params["static_feature_dimension"], + out_dimension=model_params["output_dimension"], + dense_depth=model_params["dense_layers"], + dense_dim=model_params["dense_dim"], + embedding_dim=self.edge_embedding.get_embedding_dimension(), + 
embedding_aggregation_type=model_params["embedding_aggregation"], + residual_connections=model_params["residual_connections"], + subgraph_features=model_params["subgraph_features"], + pooling_func=model_params["pooling_func"], + input_dropout=model_params["input_dropout"], + latent_dropout=model_params["latent_dropout"], + ) + self.precursor_module = GraphPropertyPredictor( + hidden_features=self.GNN_module.get_embedding_dimension(), + static_features=model_params["static_feature_dimension"], + out_dimension=1, + dense_depth=model_params["dense_layers"], + dense_dim=model_params["dense_dim"], + residual_connections=model_params["residual_connections"], + pooling_func=model_params["pooling_func"], + input_dropout=model_params["input_dropout"], + latent_dropout=model_params["latent_dropout"], + ) + if model_params["prepare_additional_layers"]: - self.RT_module = GraphPropertyPredictor(hidden_features=self.GNN_module.get_embedding_dimension(), static_features=model_params["static_rt_feature_dimension"], out_dimension=1, dense_depth=model_params["dense_layers"], dense_dim=model_params["dense_dim"], residual_connections=model_params["residual_connections"], pooling_func=model_params["pooling_func"], input_dropout=model_params["input_dropout"], latent_dropout=model_params["latent_dropout"]) - self.CCS_module = GraphPropertyPredictor(hidden_features=self.GNN_module.get_embedding_dimension(), static_features=model_params["static_rt_feature_dimension"], out_dimension=1, dense_depth=model_params["dense_layers"], dense_dim=model_params["dense_dim"], residual_connections=model_params["residual_connections"], pooling_func=model_params["pooling_func"], input_dropout=model_params["input_dropout"], latent_dropout=model_params["latent_dropout"]) - - + self.RT_module = GraphPropertyPredictor( + hidden_features=self.GNN_module.get_embedding_dimension(), + static_features=model_params["static_rt_feature_dimension"], + out_dimension=1, + dense_depth=model_params["dense_layers"], + dense_dim=model_params["dense_dim"], + residual_connections=model_params["residual_connections"], + pooling_func=model_params["pooling_func"], + input_dropout=model_params["input_dropout"], + latent_dropout=model_params["latent_dropout"], + ) + self.CCS_module = GraphPropertyPredictor( + hidden_features=self.GNN_module.get_embedding_dimension(), + static_features=model_params["static_rt_feature_dimension"], + out_dimension=1, + dense_depth=model_params["dense_layers"], + dense_dim=model_params["dense_dim"], + residual_connections=model_params["residual_connections"], + pooling_func=model_params["pooling_func"], + input_dropout=model_params["input_dropout"], + latent_dropout=model_params["latent_dropout"], + ) + self.set_transform("double_softmax") self.model_params = model_params - + def _version_control_model_params(self, model_params: Dict) -> None: - ''' Update model parameters to match the latest model version. - Args: - model_params (Dict): Dictionary containing model parameters. - ''' + """Update model parameters to match the latest model version. + Args: + model_params (Dict): Dictionary containing model parameters. 
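        Missing keys ("residual_connections", "layer_stacking",
        "prepare_additional_layers", "dense_dim", "subgraph_features",
        "pooling_func") are filled in place with backwards-compatible defaults.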
+ """ if "residual_connections" not in model_params: model_params["residual_connections"] = False if "layer_stacking" not in model_params: model_params["layer_stacking"] = False if "prepare_additional_layers" not in model_params: - model_params["prepare_additional_layers"] = True # Defaults to True, since old models have RT/CCS modules + model_params["prepare_additional_layers"] = ( + True # Defaults to True, since old models have RT/CCS modules + ) if "dense_dim" not in model_params: - model_params["dense_dim"] = None # None defaults to the number of input features (GNN output dimension) - if "subgraph_features" not in model_params: # No subgraph features in older models + model_params["dense_dim"] = ( + None # None defaults to the number of input features (GNN output dimension) + ) + if ( + "subgraph_features" not in model_params + ): # No subgraph features in older models model_params["subgraph_features"] = False if "pooling" not in model_params: model_params["pooling_func"] = "avg" @@ -61,19 +128,18 @@ def _version_control_model_params(self, model_params: Dict) -> None: return - def freeze_submodule(self, submodule_name: str): module = getattr(self, submodule_name) for param in module.parameters(): param.requires_grad = False - + def unfreeze_submodule(self, submodule_name: str): module = getattr(self, submodule_name) for param in module.parameters(): param.requires_grad = True - + def set_dropout_rate(self, input_dropout: float, latent_dropout: float) -> None: - + self.GNN_module.input_dropout.p = input_dropout self.GNN_module.latent_dropout.p = latent_dropout self.edge_module.input_dropout.p = input_dropout @@ -84,116 +150,142 @@ def set_dropout_rate(self, input_dropout: float, latent_dropout: float) -> None: self.RT_module.latent_dropout.p = latent_dropout self.CCS_module.input_dropout.p = input_dropout self.CCS_module.latent_dropout.p = latent_dropout - - def set_transform(self, transformation: Literal["softmax", "double_softmax", "off"]): + + def set_transform( + self, transformation: Literal["softmax", "double_softmax", "off"] + ): self.softmax = torch.nn.Softmax(dim=0) if transformation == "double_softmax": - self.transform = lambda y: 2. *self.softmax(y) # TODO make torch module + self.transform = lambda y: 2.0 * self.softmax(y) # TODO make torch module elif transformation == "softmax": self.transform = self.softmax elif transformation == "off": self.transform = torch.nn.Identity() else: raise ValueError(f"Unknown transformation type: {transformation}") - - ''' + """ Compile output is the heart of the fragment probability prediction. It combines edge/fragment prediction with precursor prediction for each individual graph/input and applies softmax. Then, all output values are stacked in a single dimension. 
- ''' + """ + def _compile_output(self, edge_values, graph_values, batch) -> torch.tensor: - output = torch.zeros(edge_values.shape[0] * edge_values.shape[1] + graph_values.shape[0] * 2, device=edge_values.device) + output = torch.zeros( + edge_values.shape[0] * edge_values.shape[1] + graph_values.shape[0] * 2, + device=edge_values.device, + ) batch_ptr = 0 segment_ptr = [0] # cumulative boundaries per-graph (len=num_graphs+1) - + # Map edges to graph index (repeat left nodes according to edge dimension and retrieve graph/batch index) - edge_graph_map = batch["batch"][torch.repeat_interleave(batch["edge_index"][0,:], self.edge_dim)] + edge_graph_map = batch["batch"][ + torch.repeat_interleave(batch["edge_index"][0, :], self.edge_dim) + ] for i in range(batch.num_graphs): - edges = edge_values.flatten()[edge_graph_map == i] # Retrieve edge_values for graph i - offset = edges.shape[0] + graph_values.shape[1] * 2 # Precursor prediction output is repeated to account for bi-directional edge occurances - output[batch_ptr:batch_ptr + offset,] = self.transform(torch.cat([edges, graph_values[i], graph_values[i]], axis=-1)) # concat and apply softmax + edges = edge_values.flatten()[ + edge_graph_map == i + ] # Retrieve edge_values for graph i + offset = ( + edges.shape[0] + graph_values.shape[1] * 2 + ) # Precursor prediction output is repeated to account for bi-directional edge occurances + output[batch_ptr : batch_ptr + offset,] = self.transform( + torch.cat([edges, graph_values[i], graph_values[i]], axis=-1) + ) # concat and apply softmax batch_ptr += offset segment_ptr.append(batch_ptr) - return output, torch.tensor(segment_ptr, device=edge_values.device, dtype=torch.long) - + return output, torch.tensor( + segment_ptr, device=edge_values.device, dtype=torch.long + ) + def get_graph_embedding(self, batch): batch["node_embedding"] = self.node_embedding(batch["x"]) batch["edge_embedding"] = self.edge_embedding(batch["edge_attr"]) X = self.GNN_module(batch) pooling_func = self.precursor_module.pooling_func return pooling_func(X, batch["batch"]) - def forward(self, batch, with_RT=False, with_CCS=False): # Embed node features batch["node_embedding"] = self.node_embedding(batch["x"]) batch["edge_embedding"] = self.edge_embedding(batch["edge_attr"]) - + X = self.GNN_module(batch) - + edge_values = self.edge_module(X, batch) graph_values = self.precursor_module(X, batch) - fragment_probs, segment_ptr = self._compile_output(edge_values, graph_values, batch) - - output = {'fragment_probs': fragment_probs, 'segment_ptr': segment_ptr} - + fragment_probs, segment_ptr = self._compile_output( + edge_values, graph_values, batch + ) + + output = {"fragment_probs": fragment_probs, "segment_ptr": segment_ptr} + if with_RT: rt_values = self.RT_module(X, batch, covariate_tag="static_rt_features") output["rt"] = rt_values - + if with_CCS: ccs_values = self.CCS_module(X, batch, covariate_tag="static_rt_features") output["ccs"] = ccs_values return output - + @classmethod - def load(cls, PATH: str) -> 'FioraModel': - - with open(PATH, 'rb') as f: + def load(cls, PATH: str) -> "FioraModel": + + with open(PATH, "rb") as f: model = dill.load(f) if not isinstance(model, cls): - raise ValueError(f'file {PATH} contains incorrect model class {type(model)}') + raise ValueError( + f"file {PATH} contains incorrect model class {type(model)}" + ) return model - + @classmethod - def load_from_state_dict(cls, PATH: str) -> 'FioraModel': + def load_from_state_dict(cls, PATH: str) -> "FioraModel": PARAMS_PATH = PATH.replace(".pt", 
"_params.json") STATE_PATH = PATH.replace(".pt", "_state.pt") - - with open(PARAMS_PATH, 'r') as fp: + + with open(PARAMS_PATH, "r") as fp: params = json.load(fp) model = FioraModel(params) - model.load_state_dict(torch.load(STATE_PATH, map_location=torch.serialization.default_restore_location, weights_only=True)) + model.load_state_dict( + torch.load( + STATE_PATH, + map_location=torch.serialization.default_restore_location, + weights_only=True, + ) + ) if not isinstance(model, cls): - raise ValueError(f'file {PATH} contains incorrect model class {type(model)}') + raise ValueError( + f"file {PATH} contains incorrect model class {type(model)}" + ) return model - - def save(self, PATH: str, dev: str="cpu") -> None: - + + def save(self, PATH: str, dev: str = "cpu") -> None: + prev_device = next(self.parameters()).device - + # Set device to cpu for saving self.to(dev) - with open(PATH, 'wb') as f: + with open(PATH, "wb") as f: dill.dump(self.to(dev), f) - + # Save state_dict and parameters as backup - PATH = '.'.join(PATH.split('.')[:-1]) + '_params.json' - with open(PATH, 'w') as fp: + PATH = ".".join(PATH.split(".")[:-1]) + "_params.json" + with open(PATH, "w") as fp: json.dump(self.model_params, fp) PATH = PATH.replace("_params.json", "_state.pt") torch.save(self.to(dev).state_dict(), PATH) - - #Reset to previous device - self.to(prev_device) \ No newline at end of file + + # Reset to previous device + self.to(prev_device) diff --git a/fiora/GNN/GNN.py b/fiora/GNN/GNN.py index aaec2f9..a28689a 100644 --- a/fiora/GNN/GNN.py +++ b/fiora/GNN/GNN.py @@ -2,53 +2,65 @@ import torch_geometric.nn as geom_nn from typing import Literal -''' +""" Geometric Models -''' +""" GeometricLayer = { "GraphConv": { "Layer": geom_nn.GraphConv, "divide_output_dim": False, - "const_args": {'aggr': 'mean'}, - "batch_args": {'edge_index': 'edge_index'} + "const_args": {"aggr": "mean"}, + "batch_args": {"edge_index": "edge_index"}, }, "GAT": { "Layer": geom_nn.GATConv, "divide_output_dim": True, - "const_args": {'heads': 5}, - "batch_args": {'edge_index': 'edge_index', 'edge_attr': 'edge_embedding'} + "const_args": {"heads": 5}, + "batch_args": {"edge_index": "edge_index", "edge_attr": "edge_embedding"}, }, - "RGCNConv": { "Layer": geom_nn.RGCNConv, "divide_output_dim": False, - "const_args": {'aggr': 'mean', 'num_relations': 4}, - "batch_args": {'edge_index': 'edge_index', 'edge_type': 'edge_type'} + "const_args": {"aggr": "mean", "num_relations": 4}, + "batch_args": {"edge_index": "edge_index", "edge_type": "edge_type"}, }, - "TransformerConv": { "Layer": geom_nn.TransformerConv, "divide_output_dim": True, - "const_args": {'heads': 8, 'edge_dim': 300}, - "batch_args": {'edge_index': 'edge_index', 'edge_attr': 'edge_embedding'} + "const_args": {"heads": 8, "edge_dim": 300}, + "batch_args": {"edge_index": "edge_index", "edge_attr": "edge_embedding"}, }, - "CGConv": { "Layer": geom_nn.CGConv, "divide_output_dim": False, - "const_args": {'aggr': "mean"}, #, 'dim': 300}, - "batch_args": {'edge_index': 'edge_index', 'edge_attr': 'edge_embedding'} - } + "const_args": {"aggr": "mean"}, # , 'dim': 300}, + "batch_args": {"edge_index": "edge_index", "edge_attr": "edge_embedding"}, + }, } -''' +""" Graph Neural Network (GNN) Class -''' +""" + class GNN(torch.nn.Module): - def __init__(self, hidden_features: int, depth: int, embedding_dim: int=None, embedding_aggregation_type: str='concat', gnn_type: Literal["GraphConv", "GAT", "RGCNConv", "TransformerConv", "CGConv"]="RGCNConv", layer_norm: bool=False, residual_connections: 
bool=False, layer_stacking: bool=False, input_dropout: float=0, latent_dropout: float=0) -> None: - ''' Initialize the GNN model. + def __init__( + self, + hidden_features: int, + depth: int, + embedding_dim: int = None, + embedding_aggregation_type: str = "concat", + gnn_type: Literal[ + "GraphConv", "GAT", "RGCNConv", "TransformerConv", "CGConv" + ] = "RGCNConv", + layer_norm: bool = False, + residual_connections: bool = False, + layer_stacking: bool = False, + input_dropout: float = 0, + latent_dropout: float = 0, + ) -> None: + """Initialize the GNN model. Args: hidden_features (int): Number of hidden features for each layer. depth (int): Number of graph layers. @@ -58,7 +70,7 @@ def __init__(self, hidden_features: int, depth: int, embedding_dim: int=None, em residual_connections (bool, optional): Whether to use residual connections. Defaults to False. input_dropout (float, optional): Dropout rate for input features. Defaults to 0. latent_dropout (float, optional): Dropout rate for latent features. Defaults to 0. - ''' + """ super().__init__() @@ -70,25 +82,29 @@ def __init__(self, hidden_features: int, depth: int, embedding_dim: int=None, em self.residual_connections = residual_connections self.layer_stacking = layer_stacking self.input_embedding_dim = embedding_dim - node_features = embedding_dim - - + node_features = embedding_dim + layers = [] self.layer_norms = torch.nn.ModuleList() for _ in range(depth): layers += [ GeometricLayer[gnn_type]["Layer"]( - node_features, - int(hidden_features / GeometricLayer[gnn_type]["const_args"]["heads"]) - if GeometricLayer[gnn_type]["divide_output_dim"] - else hidden_features, **GeometricLayer[gnn_type]["const_args"])] + node_features, + int( + hidden_features + / GeometricLayer[gnn_type]["const_args"]["heads"] + ) + if GeometricLayer[gnn_type]["divide_output_dim"] + else hidden_features, + **GeometricLayer[gnn_type]["const_args"], + ) + ] if layer_norm: self.layer_norms.append(torch.nn.LayerNorm(hidden_features)) node_features = hidden_features self.graph_layers = torch.nn.ModuleList(layers) - def forward(self, batch): # Initialize node embeddings X = batch["node_embedding"] @@ -98,7 +114,10 @@ def forward(self, batch): stacked_embeddings = [X] if self.layer_stacking else [] # Apply graph layers - batch_args = {key: batch[value] for key, value in GeometricLayer[self.gnn_type]["batch_args"].items()} + batch_args = { + key: batch[value] + for key, value in GeometricLayer[self.gnn_type]["batch_args"].items() + } for i, layer in enumerate(self.graph_layers): X_skip = X X = layer(X, **batch_args) @@ -122,4 +141,6 @@ def get_embedding_dimension(self): if self.input_embedding_dim is None: raise ValueError("embedding_dim must be provided when depth=0.") return self.input_embedding_dim - return self.graph_layers[-1].out_channels * (len(self.graph_layers) + 1 if self.layer_stacking else 1) \ No newline at end of file + return self.graph_layers[-1].out_channels * ( + len(self.graph_layers) + 1 if self.layer_stacking else 1 + ) diff --git a/fiora/GNN/GNNCustomLayers.py b/fiora/GNN/GNNCustomLayers.py index 69bc9e1..e56d221 100644 --- a/fiora/GNN/GNNCustomLayers.py +++ b/fiora/GNN/GNNCustomLayers.py @@ -1,5 +1,6 @@ import torch + class GCNLayer(torch.nn.Module): def __init__(self, in_features, out_features, bias=True) -> None: super().__init__() @@ -9,7 +10,5 @@ def __init__(self, in_features, out_features, bias=True) -> None: def forward(self, X, A): HW = self.W1(X) - AHW = torch.bmm(A, self.W2(X)) #A @ self.W2(X) + AHW = torch.bmm(A, self.W2(X)) # A @ 
self.W2(X)
         return self.activation(torch.add(HW, AHW))
-
-
diff --git a/fiora/GNN/GraphPropertyPredictor.py b/fiora/GNN/GraphPropertyPredictor.py
index 2287654..6e386ff 100644
--- a/fiora/GNN/GraphPropertyPredictor.py
+++ b/fiora/GNN/GraphPropertyPredictor.py
@@ -5,22 +5,38 @@
 class GraphPropertyPredictor(torch.nn.Module):
-    def __init__(self, hidden_features: int, static_features: int, out_dimension: int, dense_depth: int=0, dense_dim: int=None, residual_connections: bool=False, pooling_func: Literal["avg", "max"]="avg", input_dropout: float=0, latent_dropout: float=0) -> None:
-        ''' Initialize the GraphPropertyPredictor model.
-            Args:
-                hidden_features (int): Number of hidden features for each layer.
-                static_features (int): Number of static features to be concatenated.
-                out_dimension (int): Output dimension of the model.
-                dense_depth (int, optional): Number of dense layers. Defaults to 0.
-                dense_dim (int, optional): Dimension of the dense layers. If None, it will be set to the number of input features. Defaults to None.
-                residual_connections (bool, optional): Whether to use residual connections. Defaults to False.
-                input_dropout (float, optional): Dropout rate for input features. Defaults to 0.
-                latent_dropout (float, optional): Dropout rate for latent features. Defaults to 0.
-        '''
+    def __init__(
+        self,
+        hidden_features: int,
+        static_features: int,
+        out_dimension: int,
+        dense_depth: int = 0,
+        dense_dim: int = None,
+        residual_connections: bool = False,
+        pooling_func: Literal["avg", "max"] = "avg",
+        input_dropout: float = 0,
+        latent_dropout: float = 0,
+    ) -> None:
+        """Initialize the GraphPropertyPredictor model.
+        Args:
+            hidden_features (int): Number of hidden features for each layer.
+            static_features (int): Number of static features to be concatenated.
+            out_dimension (int): Output dimension of the model.
+            dense_depth (int, optional): Number of dense layers. Defaults to 0.
+            dense_dim (int, optional): Dimension of the dense layers. If None, it will be set to the number of input features. Defaults to None.
+            residual_connections (bool, optional): Whether to use residual connections. Defaults to False.
+            pooling_func (Literal["avg", "max"], optional): Node-to-graph pooling function. Defaults to "avg".
+            input_dropout (float, optional): Dropout rate for input features. Defaults to 0.
+            latent_dropout (float, optional): Dropout rate for latent features. Defaults to 0.
+        """
         super().__init__()
 
         self.activation = torch.nn.ELU()
-        self.pooling_func = geom_nn.global_mean_pool if pooling_func == "avg" else geom_nn.global_max_pool
+        self.pooling_func = (
+            geom_nn.global_mean_pool
+            if pooling_func == "avg"
+            else geom_nn.global_max_pool
+        )
         self.input_dropout = torch.nn.Dropout(input_dropout)
         self.latent_dropout = torch.nn.Dropout(latent_dropout)
         self.residual_connections = residual_connections
@@ -29,25 +45,29 @@ def __init__(self, hidden_features: int, static_features: int, out_dimension: in
         num_features = hidden_features + static_features
         hidden_dimension = dense_dim if dense_dim is not None else num_features
         if hidden_dimension != num_features and residual_connections:
-            raise NotImplementedError("Residual connections require the hidden dimension to match the input dimension.")
+            raise NotImplementedError(
+                "Residual connections require the hidden dimension to match the input dimension."
+ ) for _ in range(dense_depth): dense_layers += [torch.nn.Linear(num_features, hidden_dimension)] num_features = hidden_dimension self.dense_layers = torch.nn.ModuleList(dense_layers) - + self.output_layer = torch.nn.Linear(num_features, out_dimension) - + def forward(self, X, batch, covariate_tag="static_graph_features"): X = self.pooling_func(X, batch["batch"]) - X = torch.cat([X, batch[covariate_tag]], axis=-1) # self.input_dropout(batch["static_graph_features"]) - + X = torch.cat( + [X, batch[covariate_tag]], axis=-1 + ) # self.input_dropout(batch["static_graph_features"]) + for layer in self.dense_layers: X_skip = X X = self.activation(layer(X)) X = self.latent_dropout(X) if self.residual_connections: X = X + X_skip - + logits = self.output_layer(X) - - return logits \ No newline at end of file + + return logits diff --git a/fiora/GNN/Losses.py b/fiora/GNN/Losses.py index 4923a0c..1e0d158 100644 --- a/fiora/GNN/Losses.py +++ b/fiora/GNN/Losses.py @@ -1,20 +1,20 @@ import torch from torch import Tensor from torchmetrics import Metric -from torchmetrics.utilities import dim_zero_cat ### -# Weighted Mean Squared Error +# Weighted Mean Squared Error ### class WeightedMSELoss(torch.nn.Module): def __init__(self): super(WeightedMSELoss, self).__init__() def forward(self, input, target, weight): - loss = (weight * (input - target) ** 2) + loss = weight * (input - target) ** 2 return loss.mean() + class WeightedMSEMetric(Metric): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -28,17 +28,19 @@ def update(self, preds: Tensor, target: Tensor, weight: Tensor) -> None: def compute(self) -> Tensor: return self.sum / self.numel + ### -# Weighted Mean Absolute Error +# Weighted Mean Absolute Error ### class WeightedMAELoss(torch.nn.Module): def __init__(self): super(WeightedMAELoss, self).__init__() def forward(self, input, target, weight): - loss = (weight * torch.abs(input - target)) + loss = weight * torch.abs(input - target) return loss.mean() + class WeightedMAEMetric(Metric): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -52,30 +54,44 @@ def update(self, preds: Tensor, target: Tensor, weight: Tensor) -> None: def compute(self) -> Tensor: return self.sum / self.numel - + class GraphwiseKLLoss(torch.nn.Module): requires_segment_ptr = True - def __init__(self, eps: float = 1e-8, reduction: str = "mean", normalize_targets: bool = True, normalize_pred: bool = True): + def __init__( + self, + eps: float = 1e-8, + reduction: str = "mean", + normalize_targets: bool = True, + normalize_pred: bool = True, + ): super().__init__() self.eps = eps self.reduction = reduction self.normalize_targets = normalize_targets self.normalize_pred = normalize_pred - def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, segment_ptr: torch.Tensor, weight: torch.Tensor = None): - assert segment_ptr.dim() == 1 and segment_ptr.numel() >= 2, "segment_ptr must be 1D with at least 2 entries" + def forward( + self, + y_pred: torch.Tensor, + y_true: torch.Tensor, + segment_ptr: torch.Tensor, + weight: torch.Tensor = None, + ): + assert segment_ptr.dim() == 1 and segment_ptr.numel() >= 2, ( + "segment_ptr must be 1D with at least 2 entries" + ) num_graphs = segment_ptr.numel() - 1 total = y_pred.new_tensor(0.0) total_el = 0 for g in range(num_graphs): l = segment_ptr[g].item() - r = segment_ptr[g+1].item() + r = segment_ptr[g + 1].item() q = y_pred[l:r].clamp_min(self.eps) if self.normalize_pred: - q = q / q.sum().clamp_min(self.eps) # normalize q per-graph + q = q / q.sum().clamp_min(self.eps) 
# normalize q per-graph p = y_true[l:r].clamp_min(0.0) if weight is not None: @@ -83,11 +99,11 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, segment_ptr: torch p = p * w if self.normalize_targets or p.sum() <= 0: - p = p / p.sum().clamp_min(self.eps) # normalize p per-graph + p = p / p.sum().clamp_min(self.eps) # normalize p per-graph kl = (p * (p.clamp_min(self.eps).log() - q.log())).sum() total = total + kl - total_el += (r - l) + total_el += r - l if self.reduction == "sum": return total @@ -96,27 +112,51 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, segment_ptr: torch else: return total / max(num_graphs, 1) + class GraphwiseKLLossMetric(Metric): - def __init__(self, eps: float = 1e-8, reduction: str = "mean", normalize_targets: bool = True, normalize_pred: bool = True, **kwargs): + def __init__( + self, + eps: float = 1e-8, + reduction: str = "mean", + normalize_targets: bool = True, + normalize_pred: bool = True, + **kwargs, + ): super().__init__(**kwargs) self.eps = eps self.reduction = reduction self.normalize_targets = normalize_targets self.normalize_pred = normalize_pred self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total_graphs", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") - self.add_state("total_elements", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") - - def update(self, preds: Tensor, target: Tensor, segment_ptr: Tensor = None, weight: Tensor = None) -> None: + self.add_state( + "total_graphs", + default=torch.tensor(0, dtype=torch.long), + dist_reduce_fx="sum", + ) + self.add_state( + "total_elements", + default=torch.tensor(0, dtype=torch.long), + dist_reduce_fx="sum", + ) + + def update( + self, + preds: Tensor, + target: Tensor, + segment_ptr: Tensor = None, + weight: Tensor = None, + ) -> None: if segment_ptr is None: - segment_ptr = torch.tensor([0, preds.numel()], device=preds.device, dtype=torch.long) + segment_ptr = torch.tensor( + [0, preds.numel()], device=preds.device, dtype=torch.long + ) num_graphs = segment_ptr.numel() - 1 total = preds.new_tensor(0.0) total_el = 0 for g in range(num_graphs): l = segment_ptr[g].item() - r = segment_ptr[g+1].item() + r = segment_ptr[g + 1].item() if r <= l: continue @@ -133,7 +173,7 @@ def update(self, preds: Tensor, target: Tensor, segment_ptr: Tensor = None, weig kl = (p * (p.clamp_min(self.eps).log() - q.log())).sum() total = total + kl - total_el += (r - l) + total_el += r - l self.total += total self.total_graphs += torch.tensor(num_graphs, device=self.total_graphs.device) diff --git a/fiora/GNN/MLP.py b/fiora/GNN/MLP.py index cf760eb..14bbbde 100644 --- a/fiora/GNN/MLP.py +++ b/fiora/GNN/MLP.py @@ -1,20 +1,30 @@ import torch + class MLP(torch.nn.Module): - def __init__(self, input_features: int, output_dimension: int, hidden_dimension: int, hidden_layers: int = 1): + def __init__( + self, + input_features: int, + output_dimension: int, + hidden_dimension: int, + hidden_layers: int = 1, + ): super(MLP, self).__init__() self.input_features = input_features layers = [] for _ in range(hidden_layers): - layers += [torch.nn.Linear(input_features, hidden_dimension), torch.nn.ReLU()] + layers += [ + torch.nn.Linear(input_features, hidden_dimension), + torch.nn.ReLU(), + ] input_features = hidden_dimension self.hidden_layer_sequence = torch.nn.ModuleList(layers) - + self.output_layer = torch.nn.Linear(input_features, output_dimension) def forward(self, x): for layer in self.hidden_layer_sequence: x = layer(x) 
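+        # The hidden stack stores Linear and ReLU as separate modules, so this
+        # loop applies both in turn; the projection below returns raw logits.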
logits = self.output_layer(x) - return logits \ No newline at end of file + return logits diff --git a/fiora/GNN/MLPEdgeClassifier.py b/fiora/GNN/MLPEdgeClassifier.py index f969e18..9888823 100644 --- a/fiora/GNN/MLPEdgeClassifier.py +++ b/fiora/GNN/MLPEdgeClassifier.py @@ -1,20 +1,22 @@ import torch -import numpy as np + class MLPEdgeClassifier(torch.nn.Module): def __init__(self, num_node_features, num_edge_features, num_status_features): super(MLPEdgeClassifier, self).__init__() - self.num_input_features = int(2 * num_node_features + num_edge_features + num_status_features) + self.num_input_features = int( + 2 * num_node_features + num_edge_features + num_status_features + ) self.size = self.num_input_features self.hidden_layer_sequence = torch.nn.Sequential( torch.nn.Linear(self.num_input_features, 20, bias=True), torch.nn.ReLU(), ) - self.output_layer = torch.nn.Sequential( - torch.nn.Linear(20,1), + self.output_layer = torch.nn.Sequential( + torch.nn.Linear(20, 1), ) def forward(self, x, bias=True): x = self.hidden_layer_sequence(x) logits = self.output_layer(x) - return logits \ No newline at end of file + return logits diff --git a/fiora/GNN/SpectralTrainer.py b/fiora/GNN/SpectralTrainer.py index 1d2c815..4a63d76 100644 --- a/fiora/GNN/SpectralTrainer.py +++ b/fiora/GNN/SpectralTrainer.py @@ -1,145 +1,288 @@ import numpy as np import torch -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader import torch_geometric.loader as geom_loader -from torchmetrics import Accuracy, MetricTracker, MetricCollection, Precision, Recall, PrecisionRecallCurve, MeanSquaredError, MeanAbsoluteError, R2Score -from sklearn.model_selection import train_test_split -from typing import Literal, List, Callable, Any, Dict +from torchmetrics import ( + MetricTracker, + MetricCollection, +) +from typing import Literal, List, Any, Dict from fiora.GNN.Trainer import Trainer -from fiora.GNN.Datasets import collate_graph_batch, collate_graph_edge_batch from fiora.GNN.Losses import WeightedMSELoss, WeightedMAELoss -''' +""" GNN Trainer -''' +""" + class SpectralTrainer(Trainer): - def __init__(self, data: Any, train_val_split: float= 0.8, split_by_group: bool=False, only_training: bool=False, train_keys: List[int]=[], val_keys: List[int]=[], y_tag: str="y", metric_dict: Dict=None, problem_type: Literal["classification", "regression", "softmax_regression"]="classification", library: Literal["standard", "geometric"]="geometric", num_workers: int=0, seed: int=42, device: str="cpu"): - - super().__init__(data, train_val_split, split_by_group, only_training, train_keys, val_keys, seed, num_workers, device) + def __init__( + self, + data: Any, + train_val_split: float = 0.8, + split_by_group: bool = False, + only_training: bool = False, + train_keys: List[int] = [], + val_keys: List[int] = [], + y_tag: str = "y", + metric_dict: Dict = None, + problem_type: Literal[ + "classification", "regression", "softmax_regression" + ] = "classification", + library: Literal["standard", "geometric"] = "geometric", + num_workers: int = 0, + seed: int = 42, + device: str = "cpu", + ): + + super().__init__( + data, + train_val_split, + split_by_group, + only_training, + train_keys, + val_keys, + seed, + num_workers, + device, + ) self.y_tag = y_tag self.problem_type = problem_type - - # Initialize torch metrics based on dictionary + + # Initialize torch metrics based on dictionary if metric_dict: self.metrics = { - data_split: MetricTracker(MetricCollection({ - t: M() for t, M in metric_dict.items() - 
}), - maximize=False).to(device) + data_split: MetricTracker( + MetricCollection({t: M() for t, M in metric_dict.items()}), + maximize=False, + ).to(device) for data_split in ["train", "val", "masked_val", "test"] } else: self.metrics = self._get_default_metrics(problem_type) - self.loader_base = geom_loader.DataLoader if library == "geometric" else DataLoader - - def _training_loop(self, model, dataloader, optimizer, loss_fn, metrics, with_weights=False, with_RT=False, with_CCS=False, rt_metric=False, title=""): + self.loader_base = ( + geom_loader.DataLoader if library == "geometric" else DataLoader + ) + + def _training_loop( + self, + model, + dataloader, + optimizer, + loss_fn, + metrics, + with_weights=False, + with_RT=False, + with_CCS=False, + rt_metric=False, + title="", + ): training_loss = 0 - metrics.increment() + metrics.increment() for id, batch in enumerate(dataloader): - # Feed forward model.train() - + y_pred = model(batch, with_RT=with_RT, with_CCS=with_CCS) - kwargs={} + kwargs = {} if with_weights: - kwargs={"weight": batch["weight_tensor"]} + kwargs = {"weight": batch["weight_tensor"]} if getattr(loss_fn, "requires_segment_ptr", False): kwargs["segment_ptr"] = y_pred.get("segment_ptr") - + # Compute loss - loss = loss_fn(y_pred["fragment_probs"], batch[self.y_tag], **kwargs) # with logits + loss = loss_fn( + y_pred["fragment_probs"], batch[self.y_tag], **kwargs + ) # with logits if not rt_metric: - metrics(y_pred["fragment_probs"], batch[self.y_tag], **kwargs) # call update + metrics( + y_pred["fragment_probs"], batch[self.y_tag], **kwargs + ) # call update # Add RT and CCS to loss if with_RT: if with_weights: kwargs["weight"] = batch["weight"][batch["retention_mask"]] - loss_rt = loss_fn(y_pred["rt"][batch["retention_mask"]], batch["retention_time"][batch["retention_mask"]], **kwargs) + loss_rt = loss_fn( + y_pred["rt"][batch["retention_mask"]], + batch["retention_time"][batch["retention_mask"]], + **kwargs, + ) loss = loss + loss_rt - + if with_CCS: if with_weights: kwargs["weight"] = batch["weight"][batch["ccs_mask"]] - loss_ccs = loss_fn(y_pred["ccs"][batch["ccs_mask"]], batch["ccs"][batch["ccs_mask"]], **kwargs) + loss_ccs = loss_fn( + y_pred["ccs"][batch["ccs_mask"]], + batch["ccs"][batch["ccs_mask"]], + **kwargs, + ) loss = loss + loss_ccs if rt_metric: - metrics(y_pred["rt"][batch["retention_mask"]], batch["retention_time"][batch["retention_mask"]], **kwargs) # call update - metrics(y_pred["ccs"][batch["ccs_mask"]], batch["ccs"][batch["ccs_mask"]], **kwargs) # call update - + metrics( + y_pred["rt"][batch["retention_mask"]], + batch["retention_time"][batch["retention_mask"]], + **kwargs, + ) # call update + metrics( + y_pred["ccs"][batch["ccs_mask"]], + batch["ccs"][batch["ccs_mask"]], + **kwargs, + ) # call update + # Backpropagate optimizer.zero_grad() loss.backward() - optimizer.step() + optimizer.step() # End of training cycle: Evaluation stats = metrics.compute() training_loss /= len(dataloader) - + if self.problem_type == "classification": - print(f'{title} Training Accuracy: {stats["acc"]:>.3f} (Loss per batch: {"NOT TRACKED"})', end='\r') + print( + f"{title} Training Accuracy: {stats['acc']:>.3f} (Loss per batch: {'NOT TRACKED'})", + end="\r", + ) else: - print(f'{title} RMSE: {torch.sqrt(stats["mse"]):>.4f}', end='\r') #MSE: {stats["mse"]:>.3f}; MAE: {stats["mae"]:>.3f} + print( + f"{title} RMSE: {torch.sqrt(stats['mse']):>.4f}", end="\r" + ) # MSE: {stats["mse"]:>.3f}; MAE: {stats["mae"]:>.3f} return stats - - def _validation_loop(self, model, 
dataloader, loss_fn, metrics, with_weights=False, with_RT=False, with_CCS=False, rt_metric=False, mask_name=None, title="Validation"): + def _validation_loop( + self, + model, + dataloader, + loss_fn, + metrics, + with_weights=False, + with_RT=False, + with_CCS=False, + rt_metric=False, + mask_name=None, + title="Validation", + ): metrics.increment() with torch.no_grad(): for id, batch in enumerate(dataloader): - model.eval() - y_pred = model(batch, with_RT=with_RT, with_CCS=with_CCS) - if mask_name: - loss = loss_fn(y_pred["fragment_probs"][batch[mask_name]], batch[self.y_tag][batch[mask_name]]) - metrics.update(y_pred["fragment_probs"][batch[mask_name]], batch[self.y_tag][batch[mask_name]]) - else: - kwargs={} - if with_weights: - kwargs={"weight": batch["weight_tensor"]} - if getattr(loss_fn, "requires_segment_ptr", False): - kwargs["segment_ptr"] = y_pred.get("segment_ptr") - loss = loss_fn(y_pred["fragment_probs"], batch[self.y_tag], **kwargs) - if not rt_metric: - metrics.update(y_pred["fragment_probs"], batch[self.y_tag], **kwargs) - if rt_metric: - metrics(y_pred["rt"][batch["retention_mask"]], batch["retention_time"][batch["retention_mask"]], **kwargs) # call update - metrics(y_pred["ccs"][batch["ccs_mask"]], batch["ccs"][batch["ccs_mask"]], **kwargs) # call update - + model.eval() + y_pred = model(batch, with_RT=with_RT, with_CCS=with_CCS) + if mask_name: + metrics.update( + y_pred["fragment_probs"][batch[mask_name]], + batch[self.y_tag][batch[mask_name]], + ) + else: + kwargs = {} + if with_weights: + kwargs = {"weight": batch["weight_tensor"]} + if getattr(loss_fn, "requires_segment_ptr", False): + kwargs["segment_ptr"] = y_pred.get("segment_ptr") + if not rt_metric: + metrics.update( + y_pred["fragment_probs"], batch[self.y_tag], **kwargs + ) + if rt_metric: + metrics( + y_pred["rt"][batch["retention_mask"]], + batch["retention_time"][batch["retention_mask"]], + **kwargs, + ) # call update + metrics( + y_pred["ccs"][batch["ccs_mask"]], + batch["ccs"][batch["ccs_mask"]], + **kwargs, + ) # call update + # End of Validation cycle stats = metrics.compute() - print(f'\t{title} RMSE: {torch.sqrt(stats["mse"]):>.4f}') + print(f"\t{title} RMSE: {torch.sqrt(stats['mse']):>.4f}") return stats - - + # Training function - def train(self, model, optimizer, loss_fn, scheduler=None, batch_size=16, epochs=2, val_every_n_epochs=1, use_validation_mask=False, with_RT=True, with_CCS=True, rt_metric=False, mask_name="validation_mask", tag="") -> Dict[str, Any]: - + def train( + self, + model, + optimizer, + loss_fn, + scheduler=None, + batch_size=16, + epochs=2, + val_every_n_epochs=1, + use_validation_mask=False, + with_RT=True, + with_CCS=True, + rt_metric=False, + mask_name="validation_mask", + save_path: str | None = None, + tag="", + ) -> Dict[str, Any]: + # Set up checkpoint system and model info - self._init_checkpoint_system(save_path=f"../../checkpoint_{tag}.best.pt") + if save_path is None: + save_path = f"../../checkpoint_{tag}.best.pt" + self._init_checkpoint_system(save_path=save_path) self._init_history() model.model_params["training_label"] = self.y_tag - + # Stage data into dataloader - training_loader = self.loader_base(self.training_data, batch_size=batch_size, num_workers=self.num_workers, shuffle=True) + training_loader = self.loader_base( + self.training_data, + batch_size=batch_size, + num_workers=self.num_workers, + shuffle=True, + ) if not self.only_training: - validation_loader = self.loader_base(self.validation_data, batch_size=batch_size, num_workers=self.num_workers, 
shuffle=True) - using_weighted_loss_func = isinstance(loss_fn, WeightedMSELoss) | isinstance(loss_fn, WeightedMAELoss) - + validation_loader = self.loader_base( + self.validation_data, + batch_size=batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + using_weighted_loss_func = isinstance(loss_fn, WeightedMSELoss) | isinstance( + loss_fn, WeightedMAELoss + ) + # Main loop for e in range(epochs): - # Training - train_stats = self._training_loop(model, training_loader, optimizer, loss_fn, self.metrics["train"], title=f'Epoch {e + 1}/{epochs}: ', with_weights=using_weighted_loss_func, with_RT=with_RT, with_CCS=with_CCS, rt_metric=rt_metric) - + train_stats = self._training_loop( + model, + training_loader, + optimizer, + loss_fn, + self.metrics["train"], + title=f"Epoch {e + 1}/{epochs}: ", + with_weights=using_weighted_loss_func, + with_RT=with_RT, + with_CCS=with_CCS, + rt_metric=rt_metric, + ) + # Validation - is_val_cycle = not self.only_training and ((e + 1) % val_every_n_epochs == 0) - if is_val_cycle: - val_stats = self._validation_loop(model, validation_loader, loss_fn, self.metrics["masked_val"] if use_validation_mask else self.metrics["val"], with_weights=using_weighted_loss_func, with_RT=with_RT, with_CCS=with_CCS, rt_metric=rt_metric, mask_name=mask_name if use_validation_mask else None, title="Masked Validation" if use_validation_mask else "Validation") + is_val_cycle = not self.only_training and ( + (e + 1) % val_every_n_epochs == 0 + ) + if is_val_cycle: + val_stats = self._validation_loop( + model, + validation_loader, + loss_fn, + self.metrics["masked_val"] + if use_validation_mask + else self.metrics["val"], + with_weights=using_weighted_loss_func, + with_RT=with_RT, + with_CCS=with_CCS, + rt_metric=rt_metric, + mask_name=mask_name if use_validation_mask else None, + title="Masked Validation" if use_validation_mask else "Validation", + ) # End of epoch: Advance scheduler if scheduler: @@ -148,21 +291,31 @@ def train(self, model, optimizer, loss_fn, scheduler=None, batch_size=16, epochs if is_val_cycle: scheduler.step(torch.sqrt(val_stats["mse"])) if scheduler.get_last_lr()[0] < last_lr: - print(f"\t >> Learning rate reduced from {last_lr:1.0e} to {scheduler.get_last_lr()[0]:1.0e}") + print( + f"\t >> Learning rate reduced from {last_lr:1.0e} to {scheduler.get_last_lr()[0]:1.0e}" + ) else: scheduler.step() - - # Save history - if is_val_cycle: + if is_val_cycle: # Update checkpoint if val_stats["mse"].tolist() < self.checkpoint_stats["val_loss"]: - self._update_checkpoint({"epoch": e+1, "val_loss": val_stats["mse"].tolist(), "sqrt_val_loss": torch.sqrt(val_stats["mse"]).tolist()}, model) - print(f"\t >> Set new checkpoint to epoch {e+1}") - self._update_history(e+1, train_stats, val_stats, lr=scheduler.get_last_lr()[0]) + self._update_checkpoint( + { + "epoch": e + 1, + "val_loss": val_stats["mse"].tolist(), + "sqrt_val_loss": torch.sqrt(val_stats["mse"]).tolist(), + }, + model, + ) + print(f"\t >> Set new checkpoint to epoch {e + 1}") + current_lr = ( + scheduler.get_last_lr()[0] + if scheduler is not None + else optimizer.param_groups[0]["lr"] + ) + self._update_history(e + 1, train_stats, val_stats, lr=current_lr) - print("Finished Training!") return self.checkpoint_stats - \ No newline at end of file diff --git a/fiora/GNN/Trainer.py b/fiora/GNN/Trainer.py index e0207d6..17f5943 100644 --- a/fiora/GNN/Trainer.py +++ b/fiora/GNN/Trainer.py @@ -1,20 +1,38 @@ from abc import ABC, abstractmethod import torch import numpy as np -from torch.utils.data import 
DataLoader, Dataset -from torchmetrics import Accuracy, MetricTracker, MetricCollection, Precision, Recall, PrecisionRecallCurve, MeanSquaredError, MeanAbsoluteError, R2Score +from torch.utils.data import Dataset +from torchmetrics import ( + Accuracy, + MetricTracker, + MetricCollection, + Precision, + Recall, + MeanSquaredError, + MeanAbsoluteError, +) from sklearn.model_selection import train_test_split from typing import Literal, List, Dict, Any class Trainer(ABC): - def __init__(self, data: Any, train_val_split: float=0.8, split_by_group: bool=False, only_training: bool=False, - train_keys: List[int]=[], val_keys: List[int]=[], seed: int=42, num_workers: int=0, device: str="cpu") -> None: - + def __init__( + self, + data: Any, + train_val_split: float = 0.8, + split_by_group: bool = False, + only_training: bool = False, + train_keys: List[int] = [], + val_keys: List[int] = [], + seed: int = 42, + num_workers: int = 0, + device: str = "cpu", + ) -> None: + self.only_training = only_training self.num_workers = num_workers self.device = device - + if only_training: self.training_data = data self.validation_data = Dataset() @@ -23,12 +41,19 @@ def __init__(self, data: Any, train_val_split: float=0.8, split_by_group: bool=F else: train_size = int(len(data) * train_val_split) self.training_data, self.validation_data = torch.utils.data.random_split( - data, [train_size, len(data) - train_size], - generator=torch.Generator().manual_seed(seed) - ) + data, + [train_size, len(data) - train_size], + generator=torch.Generator().manual_seed(seed), + ) - - def _split_by_group(self, data, train_val_split: float, train_keys: List[int], val_keys: List[int], seed: int): + def _split_by_group( + self, + data, + train_val_split: float, + train_keys: List[int], + val_keys: List[int], + seed: int, + ): group_ids = [getattr(x, "group_id") for x in data] keys = np.unique(group_ids) if len(train_keys) > 0 and len(val_keys) > 0: @@ -43,34 +68,43 @@ def _split_by_group(self, data, train_val_split: float, train_keys: List[int], v self.training_data = torch.utils.data.Subset(data, train_ids) self.validation_data = torch.utils.data.Subset(data, val_ids) - def _get_default_metrics(self, problem_type: Literal["classification", "regression", "softmax_regression"]): + def _get_default_metrics( + self, + problem_type: Literal["classification", "regression", "softmax_regression"], + ): metrics = { - data_split: MetricTracker(MetricCollection( - { - 'acc': Accuracy("binary", num_classes=1), - 'prec': Precision('binary', num_classes=1), - 'rec': Recall('binary', num_classes=1) - }) if problem_type=="classification" else MetricCollection( - { - 'mse': MeanSquaredError(), - 'mae': MeanAbsoluteError() - })).to(self.device) - for data_split in ["train", "val", "masked_val", "test"] - } - + data_split: MetricTracker( + MetricCollection( + { + "acc": Accuracy("binary", num_classes=1), + "prec": Precision("binary", num_classes=1), + "rec": Recall("binary", num_classes=1), + } + ) + if problem_type == "classification" + else MetricCollection( + {"mse": MeanSquaredError(), "mae": MeanAbsoluteError()} + ) + ).to(self.device) + for data_split in ["train", "val", "masked_val", "test"] + } + return metrics - + def _init_checkpoint_system(self, save_path: str) -> None: self.checkpoint_stats = { "epoch": -1, "val_loss": 100000.0, "sqrt_val_loss": 100000.0, - "file": save_path} + "file": save_path, + } - def _update_checkpoint(self, new_checkpoint_data: Dict[str, Any], model, save_checkpoint: bool=True) -> None: + def _update_checkpoint( 
+ self, new_checkpoint_data: Dict[str, Any], model, save_checkpoint: bool = True + ) -> None: self.checkpoint_stats.update(new_checkpoint_data) model.save(self.checkpoint_stats["file"]) - + def _init_history(self) -> None: self.history = { "epoch": [], @@ -78,7 +112,7 @@ def _init_history(self) -> None: "sqrt_train_error": [], "val_error": [], "sqrt_val_error": [], - "lr": [] + "lr": [], } def _update_history(self, epoch, train_stats, val_stats, lr) -> None: @@ -88,12 +122,12 @@ def _update_history(self, epoch, train_stats, val_stats, lr) -> None: self.history["val_error"].append(val_stats["mse"]) self.history["sqrt_val_error"].append(torch.sqrt(val_stats["mse"]).tolist()) self.history["lr"].append(lr) - + def is_group_in_training_set(self, group_id): - return (group_id in self.train_keys) - + return group_id in self.train_keys + def is_group_in_validation_set(self, group_id): - return (group_id in self.val_keys) + return group_id in self.val_keys @abstractmethod def _training_loop(self, model, dataloader, optimizer, loss_fn, **kwargs): diff --git a/fiora/IO/LibraryLoader.py b/fiora/IO/LibraryLoader.py index 1603cac..7b863a6 100644 --- a/fiora/IO/LibraryLoader.py +++ b/fiora/IO/LibraryLoader.py @@ -1,16 +1,17 @@ import pandas as pd -class LibraryLoader(): - def __init__(self, path=None): - self.path = path - - def load_from_csv(self, path): - return pd.read_csv(path, index_col=[0], low_memory=False) - - def load_from_msp(self): - #TODO IMPLEMENT - return - - def clean_library(self): - #TODO IMPLEMENT + parameters for filtration - return \ No newline at end of file + +class LibraryLoader: + def __init__(self, path=None): + self.path = path + + def load_from_csv(self, path): + return pd.read_csv(path, index_col=[0], low_memory=False) + + def load_from_msp(self): + # TODO IMPLEMENT + return + + def clean_library(self): + # TODO IMPLEMENT + parameters for filtration + return diff --git a/fiora/IO/cfmReader.py b/fiora/IO/cfmReader.py index 3bedeb0..e3ddfca 100644 --- a/fiora/IO/cfmReader.py +++ b/fiora/IO/cfmReader.py @@ -1,23 +1,30 @@ import pandas as pd -def read(source, sep: str=" ", as_df=False): - file = open(source, 'r') + +def read(source, sep: str = " ", as_df=False): + file = open(source, "r") data = [] data_piece = {} precursor = "" mz, intensity, annotation = [], [], [] + energy2 = False for line in file: - if line == '\n': + if line == "\n": if energy2: - data_piece['peaks40'] = {'mz': mz, 'intensity': intensity, 'annotation': annotation} + data_piece["peaks40"] = { + "mz": mz, + "intensity": intensity, + "annotation": annotation, + } mz, intensity, annotation = [], [], [] energy2 = False continue else: continue - if line.startswith("#PREDICTED"): continue + if line.startswith("#PREDICTED"): + continue if line.startswith("#In-silico"): precursor = line.split("ESI-MS/MS ")[1].split(" Spectra")[0] continue @@ -26,19 +33,27 @@ def read(source, sep: str=" ", as_df=False): data.append(data_piece) data_piece, mz, intensity, annotation = {}, [], [], [] data_piece["Precursor_type"] = precursor - if '=' in line: - key = line.split('=')[0] - value = "=".join(line.strip().split('=', 1)[1:]) + if "=" in line: + key = line.split("=")[0] + value = "=".join(line.strip().split("=", 1)[1:]) data_piece[key] = value elif line.strip() == "energy0": continue elif line.strip() == "energy1": - data_piece['peaks10'] = {'mz': mz, 'intensity': intensity, 'annotation': annotation} + data_piece["peaks10"] = { + "mz": mz, + "intensity": intensity, + "annotation": annotation, + } mz, intensity, annotation = [], 
[], [] # new data piece elif line.strip() == "energy2": energy2 = True - data_piece['peaks20'] = {'mz': mz, 'intensity': intensity, 'annotation': annotation} + data_piece["peaks20"] = { + "mz": mz, + "intensity": intensity, + "annotation": annotation, + } mz, intensity, annotation = [], [], [] else: line_split = line.split(sep) @@ -52,6 +67,3 @@ def read(source, sep: str=" ", as_df=False): return pd.DataFrame(data[1:]) else: return data[1:] - - - diff --git a/fiora/IO/fraggraphReader.py b/fiora/IO/fraggraphReader.py index 74b8bc6..ebfa59e 100644 --- a/fiora/IO/fraggraphReader.py +++ b/fiora/IO/fraggraphReader.py @@ -1,19 +1,20 @@ import pandas as pd - # from https://github.com/hcji/PyCFMID/blob/master/PyCFMID/PyCFMID.py def parser_fraggraph_gen(output_file): with open(output_file) as t: output = t.readlines() - output = [s.replace('\n', '') for s in output] + output = [s.replace("\n", "") for s in output] nfrags = int(output[0]) - frag_index = [int(output[i].split(' ')[0]) for i in range(1, nfrags+1)] - frag_mass = [float(output[i].split(' ')[1]) for i in range(1, nfrags+1)] - frag_smiles = [output[i].split(' ')[2] for i in range(1, nfrags+1)] - loss_from = [int(output[i].split(' ')[0]) for i in range(nfrags+2, len(output))] - loss_to = [int(output[i].split(' ')[1]) for i in range(nfrags+2, len(output))] - loss_smiles = [output[i].split(' ')[2] for i in range(nfrags+2, len(output))] - fragments = pd.DataFrame({'index': frag_index, 'mass': frag_mass, 'smiles': frag_smiles}) - losses = pd.DataFrame({'from': loss_from, 'to': loss_to, 'smiles': loss_smiles}) - return {'fragments': fragments, 'losses': losses} \ No newline at end of file + frag_index = [int(output[i].split(" ")[0]) for i in range(1, nfrags + 1)] + frag_mass = [float(output[i].split(" ")[1]) for i in range(1, nfrags + 1)] + frag_smiles = [output[i].split(" ")[2] for i in range(1, nfrags + 1)] + loss_from = [int(output[i].split(" ")[0]) for i in range(nfrags + 2, len(output))] + loss_to = [int(output[i].split(" ")[1]) for i in range(nfrags + 2, len(output))] + loss_smiles = [output[i].split(" ")[2] for i in range(nfrags + 2, len(output))] + fragments = pd.DataFrame( + {"index": frag_index, "mass": frag_mass, "smiles": frag_smiles} + ) + losses = pd.DataFrame({"from": loss_from, "to": loss_to, "smiles": loss_smiles}) + return {"fragments": fragments, "losses": losses} diff --git a/fiora/IO/mgfReader.py b/fiora/IO/mgfReader.py index 90881fd..95048e2 100644 --- a/fiora/IO/mgfReader.py +++ b/fiora/IO/mgfReader.py @@ -1,33 +1,41 @@ -#TODO Check if first spectrum is read +# TODO Check if first spectrum is read import pandas as pd -def read(source, sep: str=" ", as_df=False, debug=False): - file = open(source, 'r') + +def read(source, sep: str = " ", as_df=False, debug=False): + file = open(source, "r") in_begin_ions = False data = [] data_piece = {} mz, intensity, ion = [], [], [] for line in file: - if debug: print(line.strip()) - if line == "MASS=Monoisotopic\n": continue #TODO edge case hacky solution - if line == '\n': continue - if line.startswith("#"): continue - if line.startswith("NA#"): continue + if debug: + print(line.strip()) + if line == "MASS=Monoisotopic\n": + continue # TODO edge case hacky solution + if line == "\n": + continue + if line.startswith("#"): + continue + if line.startswith("NA#"): + continue if line.strip() == "END IONS": in_begin_ions = False - data_piece['peaks'] = {'mz': mz, 'intensity': intensity, 'annotation': ion} + data_piece["peaks"] = {"mz": mz, "intensity": intensity, "annotation": ion} 
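+            # Each BEGIN IONS/END IONS block yields one dict: header key=value
+            # fields plus a "peaks" entry of parallel mz/intensity/annotation lists.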
data.append(data_piece) continue - + if line.strip() == "BEGIN IONS" or line.strip() == "BEGIN IONS:": in_begin_ions = True data_piece, mz, intensity, ion = {}, [], [], [] continue - if '=' in line: - key = line.split('=')[0] - value = "=".join(line.strip().split('=', 1)[1:]) #line.split('=', 1)[1].strip() + if "=" in line: + key = line.split("=")[0] + value = "=".join( + line.strip().split("=", 1)[1:] + ) # line.split('=', 1)[1].strip() data_piece[key] = value else: line_split = line.split(sep) @@ -45,7 +53,7 @@ def read(source, sep: str=" ", as_df=False, debug=False): def get_spectrum_by_name(source, name): - file = open(source, 'r') + file = open(source, "r") line_match = "TITLE=" + name + "\n" data_piece = {} @@ -54,7 +62,7 @@ def get_spectrum_by_name(source, name): for line in file: if line == "END IONS\n" and found: - data_piece['peaks'] = {'mz': mz, 'intensity': intensity, 'ion': ion} + data_piece["peaks"] = {"mz": mz, "intensity": intensity, "ion": ion} break if line == line_match: # exact name match found = True @@ -62,12 +70,12 @@ def get_spectrum_by_name(source, name): if not found: continue # skip ahead - if '=' in line: - key = line.split('=')[0] - value = line.split('=', 1)[1].strip() + if "=" in line: + key = line.split("=")[0] + value = line.split("=", 1)[1].strip() data_piece[key] = value else: - line_split = line.split(' ') + line_split = line.split(" ") mz.append(line_split[0].strip()) intensity.append(line_split[1].strip()) file.close() @@ -75,7 +83,7 @@ def get_spectrum_by_name(source, name): return data_piece -''' +""" Thoughts on format Every Spectrum becomes a dictionary @@ -94,4 +102,4 @@ def get_spectrum_by_name(source, name): sparse = (Name, sparse_vector) -''' +""" diff --git a/fiora/IO/mgfWriter.py b/fiora/IO/mgfWriter.py index ab02005..0d149c4 100644 --- a/fiora/IO/mgfWriter.py +++ b/fiora/IO/mgfWriter.py @@ -1,8 +1,12 @@ -import pandas as pd - - - -def write_mgf(df, path, peak_tag="peaks", write_header=True, headers=["TITLE", "RTINSECONDS", "PEPMASS", "CHARGE"], header_map={}, annotation=False): +def write_mgf( + df, + path, + peak_tag="peaks", + write_header=True, + headers=["TITLE", "RTINSECONDS", "PEPMASS", "CHARGE"], + header_map={}, + annotation=False, +): for h in headers: if h not in header_map.keys(): header_map[h] = h @@ -13,9 +17,9 @@ def write_mgf(df, path, peak_tag="peaks", write_header=True, headers=["TITLE", " if write_header: for key in headers: outfile.write(key + "=" + str(df.loc[x][header_map[key]]) + "\n") - for i in range(len(peaks['mz'])): - line = str(peaks['mz'][i]) + " " + str(peaks['intensity'][i]) + for i in range(len(peaks["mz"])): + line = str(peaks["mz"][i]) + " " + str(peaks["intensity"][i]) if annotation: - line += " " + peaks['annotation'][i] + line += " " + peaks["annotation"][i] outfile.write(line + "\n") outfile.write("END IONS\n") diff --git a/fiora/IO/molReader.py b/fiora/IO/molReader.py index c2ba5c9..e0be3db 100644 --- a/fiora/IO/molReader.py +++ b/fiora/IO/molReader.py @@ -1,12 +1,12 @@ from rdkit import Chem - -''' +""" Functions to read mol files -''' +""" + def load_MOL(path): - MOL_string=open(path,'r').read() + MOL_string = open(path, "r").read() m = Chem.MolFromMolBlock(MOL_string) - return m \ No newline at end of file + return m diff --git a/fiora/IO/mspReader.py b/fiora/IO/mspReader.py index fc7a6e7..e2bc47e 100644 --- a/fiora/IO/mspReader.py +++ b/fiora/IO/mspReader.py @@ -1,27 +1,23 @@ -#from pyteomics import mzml -#import regex as re +import regex as re -# import spectrum_utils.spectrum as sus - - -def 
read(source, sep=' '): - file = open(source, 'r') +def read(source, sep=" "): + file = open(source, "r") data = [] data_piece = {} mz, intensity, ion = [], [], [] for line in file: - if 'Name:' == line[0:5] or 'NAME:' == line[0:5]: - data_piece['peaks'] = {'mz': mz, 'intensity': intensity, 'annotation': ion} + if "Name:" == line[0:5] or "NAME:" == line[0:5]: + data_piece["peaks"] = {"mz": mz, "intensity": intensity, "annotation": ion} data.append(data_piece) data_piece, mz, intensity, ion = {}, [], [], [] - if ':' in line: - key = line.split(':')[0] - value = line.split(':', 1)[1].strip() + if ":" in line: + key = line.split(":")[0] + value = line.split(":", 1)[1].strip() data_piece[key] = value else: if line == "\n": @@ -30,9 +26,9 @@ def read(source, sep=' '): line_split = ls.split(sep) mz.append(float(line_split[0])) intensity.append(float(line_split[1])) - #ion.append(line_split[2].strip()) + # ion.append(line_split[2].strip()) - data_piece['peaks'] = {'mz': mz, 'intensity': intensity, 'annotation': ion} + data_piece["peaks"] = {"mz": mz, "intensity": intensity, "annotation": ion} data.append(data_piece) file.close() @@ -40,21 +36,21 @@ def read(source, sep=' '): def read_minimal(source): - file = open(source, 'r') + file = open(source, "r") data = [] data_piece = {} - mz, intensity, ion = [], [], [] + mz, intensity = [], [] for line in file: - if 'Name:' == line[0:5]: - data_piece['peaks'] = {'mz': mz, 'intensity': intensity} + if "Name:" == line[0:5]: + data_piece["peaks"] = {"mz": mz, "intensity": intensity} data.append(data_piece) - data_piece = {'Name': line.split(':', 1)[1].strip()} + data_piece = {"Name": line.split(":", 1)[1].strip()} mz, intensity = [], [] continue - if not (':' in line): - line_split = line.split('\t') + if ":" not in line: + line_split = line.split("\t") mz.append(line_split[0]) intensity.append(line_split[1]) @@ -65,35 +61,34 @@ def read_minimal(source): def read_peptides(source): - file = open(source, 'r') + file = open(source, "r") pep_list = [] for line in file: - if 'Name:' == line[0:5]: - l = line.strip('\n')[5:] - l = re.sub(r'[\d+ /]', '', l) - pep_list.append(l) + if "Name:" == line[0:5]: + li = line.strip("\n")[5:] + li = re.sub(r"[\d+ /]", "", li) + pep_list.append(li) file.close() return pep_list def read_sparse(source): - file = open(source, 'r') + file = open(source, "r") file.close() def readOld(source): - file = open(source, 'r') - c = 0 + file = open(source, "r") data = [] active_lines = [] for line in file: - if 'Name:' == line[0:5]: + if "Name:" == line[0:5]: data.append(make_data_piece(active_lines)) active_lines = [] - active_lines.append(line.strip('\n')) + active_lines.append(line.strip("\n")) data.append(make_data_piece(active_lines)) file.close() @@ -105,22 +100,22 @@ def make_data_piece(lines): mz, intensity, ion = [], [], [] for line in lines: - if ':' in line: - key = line.split(':')[0] - value = ':'.join(line.split(':')[1:]) + if ":" in line: + key = line.split(":")[0] + value = ":".join(line.split(":")[1:]) data_piece[key] = value else: - line_split = line.split('\t') + line_split = line.split("\t") mz.append(line_split[0]) intensity.append(line_split[1]) ion.append(line_split[2]) - data_piece['peaks'] = {'mz': mz, 'intensity': intensity, 'ion': ion} + data_piece["peaks"] = {"mz": mz, "intensity": intensity, "ion": ion} return data_piece def get_spectrum_by_name(source, name): - file = open(source, 'r') + file = open(source, "r") line_match = "Name: " + name + "\n" data_piece = {} @@ -129,17 +124,18 @@ def 
get_spectrum_by_name(source, name): for line in file: if line[0:5] == "Name:" and found: - data_piece['peaks'] = {'mz': mz, 'intensity': intensity, 'ion': ion} + data_piece["peaks"] = {"mz": mz, "intensity": intensity, "ion": ion} break - if line == line_match: #exact name match + if line == line_match: # exact name match found = True - if not found: continue - if ':' in line: - key = line.split(':')[0] - value = line.split(':', 1)[1].strip() + if not found: + continue + if ":" in line: + key = line.split(":")[0] + value = line.split(":", 1)[1].strip() data_piece[key] = value else: - line_split = line.split('\t') + line_split = line.split("\t") mz.append(line_split[0]) intensity.append(line_split[1]) ion.append(line_split[2].strip()) @@ -149,7 +145,7 @@ def get_spectrum_by_name(source, name): return data_piece -''' +""" Thoughts on format Every Spectrum becomes a dictionary @@ -168,4 +164,4 @@ def get_spectrum_by_name(source, name): sparse = (Name, sparse_vector) -''' +""" diff --git a/fiora/IO/mspWriter.py b/fiora/IO/mspWriter.py index 04aee46..72efbed 100644 --- a/fiora/IO/mspWriter.py +++ b/fiora/IO/mspWriter.py @@ -1,15 +1,32 @@ -import pandas as pd - - -def write_msp(df, path, write_header=True, headers=["Name", "Precursor_type", "Spectrum_type", "PRECURSORMZ", "RETENTIONTIME", "Charge", "Comments", "Num peaks"], annotation: bool=False): +def write_msp( + df, + path, + write_header=True, + headers=[ + "Name", + "Precursor_type", + "Spectrum_type", + "PRECURSORMZ", + "RETENTIONTIME", + "Charge", + "Comments", + "Num peaks", + ], + annotation: bool = False, +): with open(path, "w") as outfile: for x in df.index: peaks = df.loc[x].peaks if write_header: for key in headers: outfile.write(key + ": " + str(df.loc[x][key]) + "\n") - d = df.loc[x] outfile.write(f"Num peaks: {len(peaks['mz'])}\n") - for i in range(len(peaks['mz'])): + for i in range(len(peaks["mz"])): peak_annotation = f"\t{peaks['annotation'][i]}" if annotation else "" - outfile.write(str(peaks['mz'][i]) + "\t" + str(peaks['intensity'][i]) + peak_annotation + "\n") \ No newline at end of file + outfile.write( + str(peaks["mz"][i]) + + "\t" + + str(peaks["intensity"][i]) + + peak_annotation + + "\n" + ) diff --git a/fiora/IO/mspredReader.py b/fiora/IO/mspredReader.py index 2614174..60d253e 100644 --- a/fiora/IO/mspredReader.py +++ b/fiora/IO/mspredReader.py @@ -6,10 +6,10 @@ # Function was adjusted using code from https://github.com/samgoldman97/ms-pred def convert_dict_to_mz(values): - + peak_dict = defaultdict(lambda: {}) for k, val in values["frags"].items(): - masses, intens = val["mz_charge"], val['intens'] + masses, intens = val["mz_charge"], val["intens"] for m, i in zip(masses, intens): if i <= 0: continue @@ -28,30 +28,29 @@ def convert_dict_to_mz(values): k: dict(inten=v["inten"] / max_inten, frag_hash=v["frag_hash"]) for k, v in peak_dict.items() } - - + peaks = {"mz": [], "intensity": []} for k, v in peak_dict.items(): peaks["mz"].append(k) peaks["intensity"].append(v["inten"]) - + return peaks def read(dir): spectra = [] - + for file in os.listdir(dir): if file.endswith(".json"): temp_dict = {"file": file, "name": file.split(".")[0].replace("pred_", "")} try: - with open(os.path.join(dir, file), 'r') as fp: + with open(os.path.join(dir, file), "r") as fp: values = json.load(fp) peaks = convert_dict_to_mz(values) - temp_dict["peaks"] = peaks - + temp_dict["peaks"] = peaks + spectra.append(temp_dict) - except: + except Exception as _: print(f"Warning: unable to read {file}") - - return pd.DataFrame(spectra) \ No 
newline at end of file
+
+    return pd.DataFrame(spectra)
diff --git a/fiora/IO/mspredWriter.py b/fiora/IO/mspredWriter.py
index d32270d..c8896eb 100644
--- a/fiora/IO/mspredWriter.py
+++ b/fiora/IO/mspredWriter.py
@@ -1,30 +1,33 @@
 import os
 from fiora.MOL.constants import ADDUCT_WEIGHTS
 
-label_header = ["dataset", "spec", "name", "ionization", "formula", "smiles", "inchikey"]
+label_header = [
+    "dataset",
+    "spec",
+    "name",
+    "ionization",
+    "formula",
+    "smiles",
+    "inchikey",
+]
+
 
 def write_labels(df, output_file, label_map, from_metabolite=True):
     if from_metabolite:
         df["formula"] = df["Metabolite"].apply(lambda x: x.Formula)
         df["smiles"] = df["Metabolite"].apply(lambda x: x.SMILES)
         df["inchikey"] = df["Metabolite"].apply(lambda x: x.InChIKey)
-    
+
     if "dataset" not in label_map.keys():
         df = df.drop(columns="dataset")
     try:
-        df.rename(columns=label_map)[label_header].to_csv(output_file, index=False, sep="\t")
-    except:
-        raise NameError(f"Failed to write labels file. Make sure file path is correct. Make sure all headers are present in DataFrame {label_header}. Use label_map to rename columns.")
-
-# def write_spec_files_wo_header(df, directory, spec="spec"):
-#     for i, row in df.iterrows():
-#         output_file = os.path.join(directory, row[spec] + ".ms")
-#         with open(output_file, "w") as f:
-#             f.write("> This spectrum only containing ms 2 peaks\n")
-#             f.write("#No metadata\n\n")
-#             f.write(">ms2peaks")
-#             for j, mz in row["peaks"]["mz"]:
-#                 f.write(mz + " " + row["peaks"]["intensity"][j] + "\n")
+        df.rename(columns=label_map)[label_header].to_csv(
+            output_file, index=False, sep="\t"
+        )
+    except Exception:
+        raise NameError(
+            f"Failed to write labels file. Make sure the file path is correct and that all headers {label_header} are present in the DataFrame; use label_map to rename columns."
+ ) def write_spec_files(df, directory, spec_tag="spec"): @@ -32,13 +35,16 @@ def write_spec_files(df, directory, spec_tag="spec"): output_file = os.path.join(directory, str(row[spec_tag]) + ".ms") with open(output_file, "w") as f: metabolite = row["Metabolite"] - + # Write header f.write(">compound " + row["Name"] + " \n") f.write(">formula " + metabolite.Formula + " \n") - - - f.write(">parentmass " + str(metabolite.ExactMolWeight + ADDUCT_WEIGHTS[row["Precursor_type"]]) + " \n") + + f.write( + ">parentmass " + + str(metabolite.ExactMolWeight + ADDUCT_WEIGHTS[row["Precursor_type"]]) + + " \n" + ) f.write(">ionization " + row["Precursor_type"] + " \n") f.write(">InChi " + metabolite.InChI + " \n") f.write(">InChIKey " + metabolite.InChIKey + " \n") @@ -48,7 +54,7 @@ def write_spec_files(df, directory, spec_tag="spec"): f.write("#spectrumid " + str(row[spec_tag]) + " \n") f.write("#InChi " + metabolite.InChI + " \n") f.write("\n") - + # Write peaks f.write(">ms2peaks") for j, mz in enumerate(row["peaks"]["mz"]): @@ -56,14 +62,36 @@ def write_spec_files(df, directory, spec_tag="spec"): f.write(str(mz) + " " + str(row["peaks"]["intensity"][j])) -def write_dataset(df, directory, label_map = {"dataset": "dataset", "spec": "spec", "name": "name", "formula": "formula", "ionization": "ionization", "smiles": "smiles", "inchikey": "inchikey"}): - write_labels(df, output_file=os.path.join(directory, "labels.tsv"), label_map=label_map, from_metabolite=True) - write_labels(df.iloc[::-1], output_file=os.path.join(directory, "reverse_labels.tsv"), label_map=label_map, from_metabolite=True) - +def write_dataset( + df, + directory, + label_map={ + "dataset": "dataset", + "spec": "spec", + "name": "name", + "formula": "formula", + "ionization": "ionization", + "smiles": "smiles", + "inchikey": "inchikey", + }, +): + write_labels( + df, + output_file=os.path.join(directory, "labels.tsv"), + label_map=label_map, + from_metabolite=True, + ) + write_labels( + df.iloc[::-1], + output_file=os.path.join(directory, "reverse_labels.tsv"), + label_map=label_map, + from_metabolite=True, + ) + spec_tag = {v: k for k, v in label_map.items()}["spec"] spec_path = os.path.join(directory, "spec_files") if not os.path.exists(spec_path): os.mkdir(spec_path) write_spec_files(df, spec_path, spec_tag=spec_tag) - - return \ No newline at end of file + + return diff --git a/fiora/MOL/FragmentationTree.py b/fiora/MOL/FragmentationTree.py index 027c3b8..af87077 100644 --- a/fiora/MOL/FragmentationTree.py +++ b/fiora/MOL/FragmentationTree.py @@ -4,24 +4,29 @@ from rdkit import Chem from rdkit.Chem import AllChem -import rdkit.Chem.Descriptors as Descriptors -import numpy as np -from treelib import Node, Tree -from copy import copy +from treelib import Tree # TODO can a fragment be tied to more than one edge: Yes. 
TODO see todo case in build_frag_tree -class Fragment: +class Fragment: def __init__(self, mol, edge=None, isotope_labels=None): - + # Track edge break and reset isotope changes self.edges = [edge] if edge: - subgraph = [a.GetIsotope() for a in mol.GetAtoms()] #use isotope info as a proxy for node id - break_side = "left" if edge[0] in subgraph else "right" if edge[1] in subgraph else "unidentified" - if break_side == "unidentified": + subgraph = [ + a.GetIsotope() for a in mol.GetAtoms() + ] # use isotope info as a proxy for node id + break_side = ( + "left" + if edge[0] in subgraph + else "right" + if edge[1] in subgraph + else "unidentified" + ) + if break_side == "unidentified": print("ERROR", edge, subgraph, Chem.MolToSmiles(mol)) raise ValueError("Unidentified edge in fragment") self.break_sides = [break_side] @@ -29,30 +34,39 @@ def __init__(self, mol, edge=None, isotope_labels=None): else: self.break_sides = [None] self.subgraphs = [] - if isotope_labels: # Reset isotope info + if isotope_labels: # Reset isotope info for a in mol.GetAtoms(): id = a.GetIsotope() a.SetIsotope(isotope_labels[id]) - + # __init__ self.MOL = mol self.smiles = Chem.MolToSmiles(mol) self.neutral_mass = Chem.Descriptors.ExactMolWt(mol) self.modes = constants.DEFAULT_MODES - self.mz = {mode: self.neutral_mass + constants.ADDUCT_WEIGHTS[mode] for mode in self.modes} - self.mz.update({mode.replace("]+", "]-"): self.neutral_mass + constants.ADDUCT_WEIGHTS[mode.replace("]+", "]-")] for mode in self.modes}) - + self.mz = { + mode: self.neutral_mass + constants.ADDUCT_WEIGHTS[mode] + for mode in self.modes + } + self.mz.update( + { + mode.replace("]+", "]-"): self.neutral_mass + + constants.ADDUCT_WEIGHTS[mode.replace("]+", "]-")] + for mode in self.modes + } + ) + def __eq__(self, __o: object) -> bool: if self.neutral_mass != __o.neutral_mass: return False return self.get_morganFinger() == __o.get_morganFinger() def __repr__(self): - return " :: " + self.smiles #+ " " + str(self.mz) - + return " :: " + self.smiles # + " " + str(self.mz) + def __str__(self): - return " :: " + self.smiles #+ " " + str(self.mz) + return " :: " + self.smiles # + " " + str(self.mz) def num_of_edges(self): return len(self.edges) @@ -65,7 +79,10 @@ def match_peak(self, mz, tolerance=None): def set_modes(self, modes): self.modes = modes - self.mz = {mode: self.neutral_mass + constants.ADDUCT_WEIGHTS[mode] for mode in self.modes} + self.mz = { + mode: self.neutral_mass + constants.ADDUCT_WEIGHTS[mode] + for mode in self.modes + } def set_ID(self, ID): self.ID = ID @@ -76,24 +93,27 @@ def get_tag(self): def get_morganFinger(self): return AllChem.GetMorganFingerprintAsBitVect(self.MOL, 2, nBits=1024) + class FragmentationTree: def __init__(self, root_mol): self.root_mol = root_mol self.edge_map = {None: Fragment(root_mol)} - self.patt = Chem.MolFromSmarts('[!$([NH]!@C(=O))&!D1&!$(*#*)]-&!@[!$([NH]!@C(=O))&!D1&!$(*#*)]') + self.patt = Chem.MolFromSmarts( + "[!$([NH]!@C(=O))&!D1&!$(*#*)]-&!@[!$([NH]!@C(=O))&!D1&!$(*#*)]" + ) def __repr__(self): self.fragmentation_tree.show(idhidden=False) return "" - + def __str__(self): self.fragmentation_tree.show(idhidden=False) return "" - ''' + """ Getter - ''' + """ def get_fragment(self, id): return self.fragmentation_tree.get_node(id).data @@ -101,62 +121,85 @@ def get_fragment(self, id): def set_fragment_modes(self, modes): for frag in self.get_all_fragments(): frag.set_modes(modes) - + def get_all_fragments_as_nodes(self): return self.fragmentation_tree.all_nodes() def get_all_fragments(self): 
return [x.data for x in self.fragmentation_tree.all_nodes()] - ''' + """ Core methods - ''' + """ - def build_fragmentation_tree(self, mol, edge_indices, depth=2, parent_tree=None, parent_id=None): + def build_fragmentation_tree( + self, mol, edge_indices, depth=2, parent_tree=None, parent_id=None + ): self.fragmentation_tree = Tree(tree=parent_tree) root_fragment = Fragment(mol) root_fragment.set_ID(self.fragmentation_tree.size()) - - mol_isotopes = [a.GetIsotope() for a in mol.GetAtoms()] - for i, atom in enumerate(mol.GetAtoms()): atom.SetIsotope(i) # use isotope information as a proxy for atom_id (such that the information is not lost when carrying out the bond break) + mol_isotopes = [a.GetIsotope() for a in mol.GetAtoms()] + for i, atom in enumerate(mol.GetAtoms()): + atom.SetIsotope( + i + ) # use isotope information as a proxy for atom_id (such that the information is not lost when carrying out the bond break) - self.fragmentation_tree.create_node(tag=root_fragment.get_tag(), parent=parent_id, identifier=root_fragment.ID, data=root_fragment) - + self.fragmentation_tree.create_node( + tag=root_fragment.get_tag(), + parent=parent_id, + identifier=root_fragment.ID, + data=root_fragment, + ) listed_fragments = [] - for i,j in edge_indices: + for i, j in edge_indices: if i > j: continue - _, fragments = self.create_Fragments(mol, i, j, original_mol_isotopes=mol_isotopes) - self.edge_map[(i,j)] = {frag.break_sides[0]: frag for frag in fragments} # TODO Maybe choose to update edge_map later, this way the same fragment can exist multiple times due to multiple edges leading to the same fragments: Maybe fix. Maybe not. - - for f in fragments: + _, fragments = self.create_Fragments( + mol, i, j, original_mol_isotopes=mol_isotopes + ) + self.edge_map[(i, j)] = { + frag.break_sides[0]: frag for frag in fragments + } # TODO Maybe choose to update edge_map later, this way the same fragment can exist multiple times due to multiple edges leading to the same fragments: Maybe fix. Maybe not. 
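+            # For a non-ring edge (i, j) that produces a two-fragment break, the
+            # entry reads (hypothetical molecule):
+            #   self.edge_map[(i, j)] == {"left": <fragment>, "right": <fragment>}
+            # i.e. each break side maps to the fragment produced on that side. The
+            # atom ids stashed in the isotope field above are what let Fragment()
+            # tell the two sides apart after the bond is actually broken.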
+ for f in fragments: if f is not None: f_existing = self.get_Fragment_if_in_list(f, listed_fragments) if f_existing: f_existing.edges.append(f.edges[0]) f_existing.break_sides.append(f.break_sides[0]) continue - if depth == 1: # anchor + if depth == 1: # anchor f.set_ID(self.fragmentation_tree.size()) - self.fragmentation_tree.create_node(tag=f.get_tag(), identifier=f.ID, parent=root_fragment.ID, data=f) - else: + self.fragmentation_tree.create_node( + tag=f.get_tag(), + identifier=f.ID, + parent=root_fragment.ID, + data=f, + ) + else: # build graph, adjacency matrix and index edges G = mol_to_graph(f.MOL) A = get_adjacency_matrix(G) edge_indices = get_edges(A) f.set_ID(self.fragmentation_tree.size()) - self.fragmentation_tree = self.build_fragmentation_tree(f.MOL, edge_indices, depth=depth-1, parent_tree=self.fragmentation_tree, parent_id=root_fragment.ID) + self.fragmentation_tree = self.build_fragmentation_tree( + f.MOL, + edge_indices, + depth=depth - 1, + parent_tree=self.fragmentation_tree, + parent_id=root_fragment.ID, + ) listed_fragments.append(f) - - for i, atom in enumerate(mol.GetAtoms()): atom.SetIsotope(mol_isotopes[i]) # Reset isotope information + + for i, atom in enumerate(mol.GetAtoms()): + atom.SetIsotope(mol_isotopes[i]) # Reset isotope information return self.fragmentation_tree def create_Fragments(self, mol, i, j, original_mol_isotopes=None): - + bond = mol.GetBondBetweenAtoms(int(i), int(j)) if bond.IsInRing(): return None, [] @@ -166,11 +209,13 @@ def create_Fragments(self, mol, i, j, original_mol_isotopes=None): new_mol = None fragment_mols = [] else: - if len(fragment_mols) < 1: #TODO resolve ring break - #fragment_mols = [fragment_mols[0]] + if len(fragment_mols) < 1: # TODO resolve ring break + # fragment_mols = [fragment_mols[0]] pass - return new_mol, [Fragment(m, edge=(int(i), int(j)), isotope_labels=original_mol_isotopes) for m in fragment_mols] - + return new_mol, [ + Fragment(m, edge=(int(i), int(j)), isotope_labels=original_mol_isotopes) + for m in fragment_mols + ] def is_Fragment_in_list(self, fragment, fragment_list): for f in fragment_list: @@ -184,13 +229,12 @@ def get_Fragment_if_in_list(self, fragment, fragment_list): return f return None - def match_peak_list(self, mz_list, int_list=None, tolerance=None): fragments = self.get_all_fragments() matches = {} if not int_list: - int_list = [0] * len(mz_list) - + int_list = [0] * len(mz_list) + # Compare mz list to all fragments for i, mz in enumerate(mz_list): was_peak_matched_already = False @@ -198,27 +242,34 @@ def match_peak_list(self, mz_list, int_list=None, tolerance=None): does_match, frag_ion = frag.match_peak(mz, tolerance=tolerance) if does_match: if was_peak_matched_already: - matches[mz]['fragments'] += [frag] # Report fragment for each edge leading to it - matches[mz]['ion_modes'] += [frag_ion] - else: + matches[mz]["fragments"] += [ + frag + ] # Report fragment for each edge leading to it + matches[mz]["ion_modes"] += [frag_ion] + else: matches[mz] = { - 'intensity': int_list[i] if int_list else None, - 'fragments': [frag], # Report fragment for each edge leading to it - 'ion_modes': [frag_ion] - } + "intensity": int_list[i] if int_list else None, + "fragments": [ + frag + ], # Report fragment for each edge leading to it + "ion_modes": [frag_ion], + } was_peak_matched_already = True - + # Normalize intensity values # if sum_matched == 0: # return matches - - sum_intensity = sum([m["intensity"] for mz, m in matches.items() if m["intensity"] is not None]) + + sum_intensity = sum( + 
[m["intensity"] for mz, m in matches.items() if m["intensity"] is not None] + ) if sum_intensity > 0: for mz in matches.keys(): - int_value = matches[mz]['intensity'] - matches[mz]['relative_intensity'] = int_value / sum_intensity # only considered matched peaks - - + int_value = matches[mz]["intensity"] + matches[mz]["relative_intensity"] = ( + int_value / sum_intensity + ) # only considered matched peaks + # for mz in matches.keys(): # int_value = matches[mz]['intensity'] # matches[mz]['relative_intensity'] = int_value / sum_matched # only considered matched peaks @@ -228,14 +279,9 @@ def match_peak_list(self, mz_list, int_list=None, tolerance=None): return matches - - - - - ''' + """ Old methods - ''' - + """ def build_fragmentation_tree_from_fraggraph_df(self, df): return @@ -245,100 +291,123 @@ def build_fragmentation_tree_by_rotatable_bond_breaks(self, depth=1): mol = self.root_mol em = Chem.EditableMol(mol) nAts = mol.GetNumAtoms() - for a,b in bonds: - em.RemoveBond(a,b) + for a, b in bonds: + em.RemoveBond(a, b) em.AddAtom(Chem.Atom(0)) - em.AddBond(a,nAts,Chem.BondType.SINGLE) + em.AddBond(a, nAts, Chem.BondType.SINGLE) em.AddAtom(Chem.Atom(0)) - em.AddBond(b,nAts+1,Chem.BondType.SINGLE) - nAts+=2 + em.AddBond(b, nAts + 1, Chem.BondType.SINGLE) + nAts += 2 p = em.GetMol() Chem.SanitizeMol(p) - smis = [Chem.MolToSmiles(x,True) for x in Chem.GetMolFrags(p,asMols=True)] - for smi in smis: print(smi) + smis = [Chem.MolToSmiles(x, True) for x in Chem.GetMolFrags(p, asMols=True)] + for smi in smis: + print(smi) - - #from rdkit.Chem import BRICS - #bonds = [((x,y),(0,0)) for x,y in bonds] + # from rdkit.Chem import BRICS + # bonds = [((x,y),(0,0)) for x,y in bonds] + # # - # - #p = BRICS.BreakBRICSBonds(mol,bonds=bonds) + # p = BRICS.BreakBRICSBonds(mol,bonds=bonds) # - #smis = [Chem.MolToSmiles(x,True) for x in Chem.GetMolFrags(p,asMols=True)] - #for smi in smis: print(smi) + # smis = [Chem.MolToSmiles(x,True) for x in Chem.GetMolFrags(p,asMols=True)] + # for smi in smis: print(smi) # return smi - - def build_fragmentation_tree_by_single_edge_breaks(self, mol, edge_indices, depth=2, parent_tree=None, parent_id=None): + def build_fragmentation_tree_by_single_edge_breaks( + self, mol, edge_indices, depth=2, parent_tree=None, parent_id=None + ): self.fragmentation_tree = Tree(tree=parent_tree) ID = self.fragmentation_tree.size() - self.fragmentation_tree.create_node(tag=Chem.Descriptors.ExactMolWt(mol), parent=parent_id, identifier=ID, data=mol) - + self.fragmentation_tree.create_node( + tag=Chem.Descriptors.ExactMolWt(mol), + parent=parent_id, + identifier=ID, + data=mol, + ) listed_fragments = [] - - for i,j in edge_indices: + for i, j in edge_indices: _, fragments = self.create_fragments(mol, i, j) for f in fragments: if f is not None: if self.is_fragment_in_list(f, listed_fragments): continue - if depth == 1: # anchor - self.fragmentation_tree.create_node(tag=Chem.Descriptors.ExactMolWt(f), identifier=self.fragmentation_tree.size(), parent=ID, data=f) - else: # recursion TODO OPTIMIZE + if depth == 1: # anchor + self.fragmentation_tree.create_node( + tag=Chem.Descriptors.ExactMolWt(f), + identifier=self.fragmentation_tree.size(), + parent=ID, + data=f, + ) + else: # recursion TODO OPTIMIZE # build graph, adjacency matrix and index edges G = mol_to_graph(f) A = get_adjacency_matrix(G) - edge_indices = get_edge_indices(A) - self.fragmentation_tree = self.build_fragmentation_tree_by_single_edge_breaks(f, edge_indices, depth=depth-1, parent_tree=self.fragmentation_tree, parent_id=ID) + 
edge_indices = (A > 0.001).nonzero(as_tuple=False) + edge_indices = edge_indices[ + edge_indices[:, 0] < edge_indices[:, 1] + ].tolist() + self.fragmentation_tree = ( + self.build_fragmentation_tree_by_single_edge_breaks( + f, + edge_indices, + depth=depth - 1, + parent_tree=self.fragmentation_tree, + parent_id=ID, + ) + ) listed_fragments.append(f) return self.fragmentation_tree - def break_bond(self, mol, i,j, add_dummy_atoms=False): + def break_bond(self, mol, i, j, add_dummy_atoms=False): num_atoms = mol.GetNumAtoms() em = Chem.EditableMol(mol) em.RemoveBond(i, j) - + if add_dummy_atoms: - em.AddAtom(Chem.Atom(0)) # - em.AddBond(i,num_atoms,Chem.BondType.SINGLE) # - em.AddAtom(Chem.Atom(0)) # - em.AddBond(j,num_atoms+1,Chem.BondType.SINGLE) # - - new_mol = em.GetMol() - Chem.SanitizeMol(new_mol) # + em.AddAtom(Chem.Atom(0)) # + em.AddBond(i, num_atoms, Chem.BondType.SINGLE) # + em.AddAtom(Chem.Atom(0)) # + em.AddBond(j, num_atoms + 1, Chem.BondType.SINGLE) # + + new_mol = em.GetMol() + Chem.SanitizeMol(new_mol) # frags = Chem.GetMolFrags(new_mol, asMols=True) return new_mol, frags - def create_fragments(self, mol, i, j): try: new_mol, fragments = self.break_bond(mol, int(i), int(j)) except (Chem.AtomKekulizeException, Chem.KekulizeException): - #print(i,j, "Error", Chem.AtomKekulizeException) + # print(i,j, "Error", Chem.AtomKekulizeException) new_mol = None fragments = [None, None] else: if len(fragments) < 1: - #TODO resolve ring break + # TODO resolve ring break fragments = [fragments[0], None] - return new_mol, fragments + return new_mol, fragments def morganFinger(self, x): return AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) def equalMols(self, mol, other): - funcs = [Chem.Descriptors.ExactMolWt, self.morganFinger, AllChem.GetMACCSKeysFingerprint] - #func = Chem.Descriptors.ExactMolWt # TODO add more here !!!!! When are mols equal???? + funcs = [ + Chem.Descriptors.ExactMolWt, + self.morganFinger, + AllChem.GetMACCSKeysFingerprint, + ] + # func = Chem.Descriptors.ExactMolWt # TODO add more here !!!!! When are mols equal???? 
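+        # All comparators must agree for two mols to count as equal; a single
+        # mismatch is enough to tell them apart (e.g. constitutional isomers share
+        # ExactMolWt but typically differ in their Morgan/MACCS bits).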
for func in funcs: if func(mol) == func(other): continue diff --git a/fiora/MOL/Metabolite.py b/fiora/MOL/Metabolite.py index f169ef5..de948d3 100644 --- a/fiora/MOL/Metabolite.py +++ b/fiora/MOL/Metabolite.py @@ -16,44 +16,70 @@ import networkx as nx -from fiora.MOL.constants import DEFAULT_PPM, DEFAULT_MODES, DEFAULT_MODE_MAP, ADDUCT_WEIGHTS, ORDERED_ELEMENT_LIST_WITH_HYDROGEN, MAX_SUBGRAPH_NODES -from fiora.MOL.mol_graph import mol_to_graph, get_adjacency_matrix, get_degree_matrix, get_edges, get_identity_matrix, draw_graph, compute_edge_related_helper_matrices, get_helper_matrices_from_edges -from fiora.MOL.FragmentationTree import FragmentationTree +from fiora.MOL.constants import ( + DEFAULT_PPM, + DEFAULT_MODES, + DEFAULT_MODE_MAP, + ADDUCT_WEIGHTS, + ORDERED_ELEMENT_LIST_WITH_HYDROGEN, + MAX_SUBGRAPH_NODES, +) +from fiora.MOL.mol_graph import ( + mol_to_graph, + get_adjacency_matrix, + get_degree_matrix, + get_edges, + get_identity_matrix, + draw_graph, + compute_edge_related_helper_matrices, + get_helper_matrices_from_edges, +) +from fiora.MOL.FragmentationTree import FragmentationTree from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder class Metabolite: - def __init__(self, SMILES: str|None, InChI: str|None=None, id: int|None=None) -> None: + def __init__( + self, SMILES: str | None, InChI: str | None = None, id: int | None = None + ) -> None: if SMILES: self.SMILES = SMILES self.MOL = Chem.MolFromSmiles(self.SMILES) if not self.MOL: - raise AssertionError("Molecule invalid; could not be generated from SMILES") + raise AssertionError( + "Molecule invalid; could not be generated from SMILES" + ) self.InChI = Chem.MolToInchi(self.MOL) self.InChIKey = Chem.InchiToInchiKey(self.InChI) elif InChI: self.InChI = InChI self.MOL = Chem.MolFromInchi(self.InChI) if not self.MOL: - raise AssertionError("Molecule invalid; could not be generated from InChI") + raise AssertionError( + "Molecule invalid; could not be generated from InChI" + ) self.InChIKey = Chem.InchiToInchiKey(self.InChI) - self.SMILES = Chem.MolToSmiles(self.MOL) + self.SMILES = Chem.MolToSmiles(self.MOL) else: raise ValueError("Neither SMILES nor InChI were specified.") - + self.ExactMolWeight = Descriptors.ExactMolWt(self.MOL) self.Formula = rdMolDescriptors.CalcMolFormula(self.MOL) - self.morganFinger = AllChem.GetMorganFingerprintAsBitVect(self.MOL, 2, nBits=2048) #1024 - self.morganFinger3 = AllChem.GetMorganFingerprintAsBitVect(self.MOL, 3, nBits=2048) #1024 + self.morganFinger = AllChem.GetMorganFingerprintAsBitVect( + self.MOL, 2, nBits=2048 + ) # 1024 + self.morganFinger3 = AllChem.GetMorganFingerprintAsBitVect( + self.MOL, 3, nBits=2048 + ) # 1024 self.morganFingerCountOnes = self.morganFinger.GetNumOnBits() self.id = id self.loss_weight = 1.0 def __repr__(self): return f"" - + def __str__(self): return f"" @@ -61,12 +87,14 @@ def __eq__(self, __o: object) -> bool: if self.ExactMolWeight != __o.ExactMolWeight: return False # Compare the number of bits=1 to prefilter mismatching Metabolites, since it is a much faster comparison - if self.morganFingerCountOnes != __o.morganFingerCountOnes: + if self.morganFingerCountOnes != __o.morganFingerCountOnes: return False return self.get_morganFinger() == __o.get_morganFinger() - def __lt__(self, __o: object) -> bool: # TODO not tested!s - warnings.warn("Warning: < operation for Metabolite class is not tested. 
Potentially flawed.") + def __lt__(self, __o: object) -> bool: # TODO not tested!s + warnings.warn( + "Warning: < operation for Metabolite class is not tested. Potentially flawed." + ) if self.ExactMolWeight < __o.ExactMolWeight: return True for bit_this, bit_other in zip(self.get_morganFinger(), __o.get_morganFinger()): @@ -75,42 +103,50 @@ def __lt__(self, __o: object) -> bool: # TODO not tested!s elif bit_other < bit_this: return False return False - + def get_id(self): return self.id - + def set_id(self, id): self.id = id def set_loss_weight(self, weight): self.loss_weight = weight - def get_theoretical_precursor_mz(self, ion_type: str=None): + def get_theoretical_precursor_mz(self, ion_type: str = None): if ion_type is None: - if hasattr(self, 'metadata') and 'precursor_mode' in self.metadata: - ion_type = self.metadata['precursor_mode'] + if hasattr(self, "metadata") and "precursor_mode" in self.metadata: + ion_type = self.metadata["precursor_mode"] else: - raise ValueError("Ion type is not specified and no precursor_mode found in metadata.") + raise ValueError( + "Ion type is not specified and no precursor_mode found in metadata." + ) return self.ExactMolWeight + ADDUCT_WEIGHTS[ion_type] def get_morganFinger(self): return self.morganFinger - - def tanimoto_similarity(self, __o: object, finger: Literal["morgan2", "morgan3"]="morgan2"): + + def tanimoto_similarity( + self, __o: object, finger: Literal["morgan2", "morgan3"] = "morgan2" + ): if finger == "morgan2": - return DataStructs.TanimotoSimilarity(self.get_morganFinger(), __o.get_morganFinger()) + return DataStructs.TanimotoSimilarity( + self.get_morganFinger(), __o.get_morganFinger() + ) if finger == "morgan3": return DataStructs.TanimotoSimilarity(self.morganFinger3, __o.morganFinger3) - raise ValueError(f"Unknown type of fingerprint: {finger}. Cannot compare Metabolites.") + raise ValueError( + f"Unknown type of fingerprint: {finger}. Cannot compare Metabolites." 
+ ) - def draw(self, ax=plt, show: bool=False, high_res: bool=False): + def draw(self, ax=plt, show: bool = False, high_res: bool = False): if high_res: # Generate high-resolution SVG drawer = rdMolDraw2D.MolDraw2DSVG(500, 500) drawer.DrawMolecule(self.MOL) drawer.FinishDrawing() - img = SVG(drawer.GetDrawingText()) - + img = SVG(drawer.GetDrawingText()) + # Display the SVG inline in the notebook if show: display(img) @@ -119,24 +155,32 @@ def draw(self, ax=plt, show: bool=False, high_res: bool=False): # Generate low-resolution image img = Draw.MolToImage(self.MOL, ax=ax) ax.grid(False) - ax.tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False) + ax.tick_params( + axis="both", + bottom=False, + labelbottom=False, + left=False, + labelleft=False, + ) ax.imshow(img) ax.axis("off") if show: plt.show() return img - - # class-specific functions def create_molecular_structure_graph(self): self.Graph: nx.Graph = mol_to_graph(self.MOL) - - - def compute_graph_attributes(self, node_encoder: AtomFeatureEncoder|None = None, bond_encoder: BondFeatureEncoder|None = None, memory_safe: bool = False) -> None: + + def compute_graph_attributes( + self, + node_encoder: AtomFeatureEncoder | None = None, + bond_encoder: BondFeatureEncoder | None = None, + memory_safe: bool = False, + ) -> None: # Adjacency - A = get_adjacency_matrix(self.Graph) + A = get_adjacency_matrix(self.Graph) self.edges = A.nonzero() self.edges_as_tuples = get_edges(A) @@ -147,57 +191,133 @@ def compute_graph_attributes(self, node_encoder: AtomFeatureEncoder|None = None, # self.Anorm = self.A / self.deg # self.AL, self.AR, self.edges = compute_edge_related_helper_matrices(self.A, self.deg) # self.AL, self.AR = get_helper_matrices_from_edges(self.edges_as_tuples, self.A) - + # Labels - self.is_node_aromatic = torch.tensor([[self.Graph.nodes[atom]['is_aromatic'] for atom in self.Graph.nodes()]], dtype=torch.float32).t() - self.is_edge_aromatic = torch.tensor([[self.Graph[u][v]['bond_type'].name == "AROMATIC" for u,v in self.edges_as_tuples]], dtype=torch.float32).t() - self.is_edge_in_ring = torch.tensor([[self.Graph[u][v]['bond'].IsInRing() for u,v in self.edges_as_tuples]], dtype=torch.float32).t() - self.is_edge_not_in_ring = torch.tensor([[not self.Graph[u][v]['bond'].IsInRing() for u,v in self.edges_as_tuples]], dtype=torch.float32).t() + self.is_node_aromatic = torch.tensor( + [[self.Graph.nodes[atom]["is_aromatic"] for atom in self.Graph.nodes()]], + dtype=torch.float32, + ).t() + self.is_edge_aromatic = torch.tensor( + [ + [ + self.Graph[u][v]["bond_type"].name == "AROMATIC" + for u, v in self.edges_as_tuples + ] + ], + dtype=torch.float32, + ).t() + self.is_edge_in_ring = torch.tensor( + [[self.Graph[u][v]["bond"].IsInRing() for u, v in self.edges_as_tuples]], + dtype=torch.float32, + ).t() + self.is_edge_not_in_ring = torch.tensor( + [ + [ + not self.Graph[u][v]["bond"].IsInRing() + for u, v in self.edges_as_tuples + ] + ], + dtype=torch.float32, + ).t() self.ring_proportion = sum(self.is_edge_in_ring) / len(self.is_edge_in_ring) - self.edge_forward_direction = torch.tensor([[bool(u < v) for u,v in self.edges_as_tuples]], dtype=torch.bool).t() - self.edge_backward_direction = torch.tensor([[bool(u > v) for u,v in self.edges_as_tuples]], dtype=torch.bool).t() - + self.edge_forward_direction = torch.tensor( + [[bool(u < v) for u, v in self.edges_as_tuples]], dtype=torch.bool + ).t() + self.edge_backward_direction = torch.tensor( + [[bool(u > v) for u, v in self.edges_as_tuples]], dtype=torch.bool + 
).t() + # Lists if not memory_safe: - self.atoms_in_order = [self.Graph.nodes[atom]['atom'] for atom in self.Graph.nodes()] - self.node_elements = [self.Graph.nodes[atom]['atom'].GetSymbol() for atom in self.Graph.nodes()] - self.edge_bond_names = [self.Graph[u][v]['bond_type'].name for u,v in self.edges_as_tuples] - + self.atoms_in_order = [ + self.Graph.nodes[atom]["atom"] for atom in self.Graph.nodes() + ] + self.node_elements = [ + self.Graph.nodes[atom]["atom"].GetSymbol() + for atom in self.Graph.nodes() + ] + self.edge_bond_names = [ + self.Graph[u][v]["bond_type"].name for u, v in self.edges_as_tuples + ] + # Features if node_encoder: self.node_features = node_encoder.encode(self.Graph, encoder_type="number") - self.node_features_one_hot = node_encoder.encode(self.Graph, encoder_type="one_hot") + self.node_features_one_hot = node_encoder.encode( + self.Graph, encoder_type="one_hot" + ) if bond_encoder: - self.edge_bond_types = torch.tensor([bond_encoder.number_mapper["bond_type"][bond_name] for bond_name in self.edge_bond_names], dtype=torch.int64) - self.bond_features = bond_encoder.encode(self.Graph, self.edges_as_tuples, encoder_type="number") - self.bond_features_one_hot = bond_encoder.encode(self.Graph, self.edges_as_tuples, encoder_type="one_hot") + self.edge_bond_types = torch.tensor( + [ + bond_encoder.number_mapper["bond_type"][bond_name] + for bond_name in self.edge_bond_names + ], + dtype=torch.int64, + ) + self.bond_features = bond_encoder.encode( + self.Graph, self.edges_as_tuples, encoder_type="number" + ) + self.bond_features_one_hot = bond_encoder.encode( + self.Graph, self.edges_as_tuples, encoder_type="one_hot" + ) else: - self.bond_features = torch.zeros(len(self.edges_as_tuples), 0, dtype=torch.float32) - - def add_metadata(self, metadata, covariate_encoder: CovariateFeatureEncoder=None, rt_feature_encoder: CovariateFeatureEncoder=None, process_metadata: bool = True, max_RT=30.0): + self.bond_features = torch.zeros( + len(self.edges_as_tuples), 0, dtype=torch.float32 + ) + + def add_metadata( + self, + metadata, + covariate_encoder: CovariateFeatureEncoder = None, + rt_feature_encoder: CovariateFeatureEncoder = None, + process_metadata: bool = True, + max_RT=30.0, + ): self.metadata = metadata mol_metadata = {"molecular_weight": self.ExactMolWeight} metadata.update(mol_metadata) if not process_metadata: return - + if covariate_encoder: self.setup_features = covariate_encoder.encode(1, metadata, G=self.Graph) - self.setup_features_per_edge = covariate_encoder.encode(len(self.edges_as_tuples), metadata, G=self.Graph) + self.setup_features_per_edge = covariate_encoder.encode( + len(self.edges_as_tuples), metadata, G=self.Graph + ) if "ce_steps" in metadata: - self.ce_steps = torch.tensor([covariate_encoder.normalize_collision_steps(metadata["ce_steps"]) + [np.nan for _ in range(7 - len(metadata["ce_steps"]))]]) # nan padding + self.ce_steps = torch.tensor( + [ + covariate_encoder.normalize_collision_steps( + metadata["ce_steps"] + ) + + [np.nan for _ in range(7 - len(metadata["ce_steps"]))] + ] + ) # nan padding else: - self.ce_steps = torch.tensor([np.nan] * 7, dtype=torch.float).unsqueeze(0) - self.ce_idx = torch.tensor(covariate_encoder.one_hot_mapper["collision_energy"], dtype=int).unsqueeze(dim=-1) + self.ce_steps = torch.tensor([np.nan] * 7, dtype=torch.float).unsqueeze( + 0 + ) + self.ce_idx = torch.tensor( + covariate_encoder.one_hot_mapper["collision_energy"], dtype=int + ).unsqueeze(dim=-1) else: self.setup_features = torch.zeros(1, 0, 
dtype=torch.float32) - self.setup_features_per_edge = torch.zeros(len(self.edges_as_tuples), 0, dtype=torch.float32) - + self.setup_features_per_edge = torch.zeros( + len(self.edges_as_tuples), 0, dtype=torch.float32 + ) + if rt_feature_encoder: - self.rt_setup_features = rt_feature_encoder.encode(1, metadata, G=self.Graph) - + self.rt_setup_features = rt_feature_encoder.encode( + 1, metadata, G=self.Graph + ) + if "retention_time" in metadata.keys(): - if not metadata["retention_time"] or np.isnan(metadata["retention_time"]) or "GC" in str(metadata["instrument"]) or metadata["retention_time"] > max_RT: + if ( + not metadata["retention_time"] + or np.isnan(metadata["retention_time"]) + or "GC" in str(metadata["instrument"]) + or metadata["retention_time"] > max_RT + ): metadata["retention_time"] = np.nan self.rt = torch.tensor([np.nan]).unsqueeze(dim=-1) self.rt_mask = torch.tensor([0], dtype=torch.bool).unsqueeze(dim=-1) @@ -207,9 +327,13 @@ def add_metadata(self, metadata, covariate_encoder: CovariateFeatureEncoder=None else: self.rt = torch.tensor([torch.nan]).unsqueeze(dim=-1) self.rt_mask = torch.tensor([0], dtype=torch.bool).unsqueeze(dim=-1) - + if "ccs" in metadata.keys(): - if not metadata["ccs"] or np.isnan(metadata["ccs"]) or "GC" in str(metadata["instrument"]): + if ( + not metadata["ccs"] + or np.isnan(metadata["ccs"]) + or "GC" in str(metadata["instrument"]) + ): metadata["ccs"] = np.nan self.ccs_mask = torch.tensor([0], dtype=torch.bool).unsqueeze(dim=-1) self.ccs = torch.tensor([np.nan]).unsqueeze(dim=-1) @@ -225,30 +349,44 @@ def is_single_connected_structure(self): if len(fragments) == 1: return True return False - + def fragment_MOL(self, depth=1): self.fragmentation_tree = FragmentationTree(self.MOL) - self.fragmentation_tree.build_fragmentation_tree(self.MOL, self.edges_as_tuples, depth=depth) + self.fragmentation_tree.build_fragmentation_tree( + self.MOL, self.edges_as_tuples, depth=depth + ) self.extract_subgraph_features_from_edges() - + def add_fragmentation_tree(self, fragmentation_tree: FragmentationTree): self.fragmentation_tree = fragmentation_tree self.extract_subgraph_features_from_edges() def extract_subgraph_features_from_edges(self) -> None: if self.fragmentation_tree is None: - self.subgraph_elem_comp = torch.zeros(0, 2 * len(ORDERED_ELEMENT_LIST_WITH_HYDROGEN), dtype=torch.float32) - self.subgraph_idx_left = torch.full((0, MAX_SUBGRAPH_NODES), -1, dtype=torch.int64) - self.subgraph_idx_right = torch.full((0, MAX_SUBGRAPH_NODES), -1, dtype=torch.int64) + self.subgraph_elem_comp = torch.zeros( + 0, 2 * len(ORDERED_ELEMENT_LIST_WITH_HYDROGEN), dtype=torch.float32 + ) + self.subgraph_idx_left = torch.full( + (0, MAX_SUBGRAPH_NODES), -1, dtype=torch.int64 + ) + self.subgraph_idx_right = torch.full( + (0, MAX_SUBGRAPH_NODES), -1, dtype=torch.int64 + ) return edge_map = self.fragmentation_tree.edge_map num_edges = len(self.edges) # Initialize tensors for element composition and subgraph node indices - self.subgraph_elem_comp = torch.zeros(num_edges, 2 * len(ORDERED_ELEMENT_LIST_WITH_HYDROGEN), dtype=torch.float32) - self.subgraph_idx_left = torch.full((num_edges, MAX_SUBGRAPH_NODES), -1, dtype=torch.int64) - self.subgraph_idx_right = torch.full((num_edges, MAX_SUBGRAPH_NODES), -1, dtype=torch.int64) + self.subgraph_elem_comp = torch.zeros( + num_edges, 2 * len(ORDERED_ELEMENT_LIST_WITH_HYDROGEN), dtype=torch.float32 + ) + self.subgraph_idx_left = torch.full( + (num_edges, MAX_SUBGRAPH_NODES), -1, dtype=torch.int64 + ) + self.subgraph_idx_right = torch.full( + 
(num_edges, MAX_SUBGRAPH_NODES), -1, dtype=torch.int64 + ) for i, edge in enumerate(self.edges): u, v = edge[0].item(), edge[1].item() @@ -269,21 +407,43 @@ def extract_subgraph_features_from_edges(self) -> None: right_fragment = frag_list["left"] # Initialize element composition for the edge - edge_elem_comp = torch.zeros(2 * len(ORDERED_ELEMENT_LIST_WITH_HYDROGEN), dtype=torch.float32) + edge_elem_comp = torch.zeros( + 2 * len(ORDERED_ELEMENT_LIST_WITH_HYDROGEN), dtype=torch.float32 + ) if left_fragment is not None and right_fragment is not None: # Compute element composition for left and right fragments - edge_elem_comp[:len(ORDERED_ELEMENT_LIST_WITH_HYDROGEN)] = CovariateFeatureEncoder.get_element_composition(self.Graph.subgraph(left_fragment.subgraphs[0])).clone().detach().to(torch.float32) - edge_elem_comp[len(ORDERED_ELEMENT_LIST_WITH_HYDROGEN):] = CovariateFeatureEncoder.get_element_composition(self.Graph.subgraph(right_fragment.subgraphs[0])).clone().detach().to(torch.float32) - + edge_elem_comp[: len(ORDERED_ELEMENT_LIST_WITH_HYDROGEN)] = ( + CovariateFeatureEncoder.get_element_composition( + self.Graph.subgraph(left_fragment.subgraphs[0]) + ) + .clone() + .detach() + .to(torch.float32) + ) + edge_elem_comp[len(ORDERED_ELEMENT_LIST_WITH_HYDROGEN) :] = ( + CovariateFeatureEncoder.get_element_composition( + self.Graph.subgraph(right_fragment.subgraphs[0]) + ) + .clone() + .detach() + .to(torch.float32) + ) + # Get node indices and truncate if necessary left_nodes = torch.tensor(left_fragment.subgraphs[0], dtype=torch.int64) - right_nodes = torch.tensor(right_fragment.subgraphs[0], dtype=torch.int64) - - if len(left_nodes) > MAX_SUBGRAPH_NODES or len(right_nodes) > MAX_SUBGRAPH_NODES: - warnings.warn(f"Metabolite {self.SMILES}: Subgraph size ({max(len(left_nodes), len(right_nodes))}) exceeds MAX_SUBGRAPH_NODES ({MAX_SUBGRAPH_NODES}). Truncating.") + right_nodes = torch.tensor( + right_fragment.subgraphs[0], dtype=torch.int64 + ) + + if ( + len(left_nodes) > MAX_SUBGRAPH_NODES + or len(right_nodes) > MAX_SUBGRAPH_NODES + ): + warnings.warn( + f"Metabolite {self.SMILES}: Subgraph size ({max(len(left_nodes), len(right_nodes))}) exceeds MAX_SUBGRAPH_NODES ({MAX_SUBGRAPH_NODES}). Truncating." 
+                )
-
             len_left = min(len(left_nodes), MAX_SUBGRAPH_NODES)
             len_right = min(len(right_nodes), MAX_SUBGRAPH_NODES)
@@ -292,13 +452,38 @@ def extract_subgraph_features_from_edges(self) -> None:
             # Store the element composition for the edge
             self.subgraph_elem_comp[i, :] = edge_elem_comp
-
-    def match_fragments_to_peaks(self, mz_fragments, int_list=None, mode_map_override=None, tolerance=DEFAULT_PPM, match_stats_only: bool = False):
-        self.peak_matches = self.fragmentation_tree.match_peak_list(mz_fragments, int_list, tolerance=tolerance)
-        self.edge_breaks = [frag.edges for mz in self.peak_matches.keys() for frag in self.peak_matches[mz]['fragments']]
-        self.edge_breaks = [e for edges in self.edge_breaks for e in edges] # Flatten the edge breaks
-        edge_break_labels = torch.tensor([[1.0 if (u, v) in self.edge_breaks or (v, u) in self.edge_breaks else 0.0 for u,v in self.edges_as_tuples]], dtype=torch.float32).t()
-
+
+    def match_fragments_to_peaks(
+        self,
+        mz_fragments,
+        int_list=None,
+        mode_map_override=None,
+        tolerance=DEFAULT_PPM,
+        match_stats_only: bool = False,
+    ):
+        self.peak_matches = self.fragmentation_tree.match_peak_list(
+            mz_fragments, int_list, tolerance=tolerance
+        )
+        self.edge_breaks = [
+            frag.edges
+            for mz in self.peak_matches.keys()
+            for frag in self.peak_matches[mz]["fragments"]
+        ]
+        self.edge_breaks = [
+            e for edges in self.edge_breaks for e in edges
+        ]  # Flatten the edge breaks
+        edge_break_labels = torch.tensor(
+            [
+                [
+                    1.0
+                    if (u, v) in self.edge_breaks or (v, u) in self.edge_breaks
+                    else 0.0
+                    for u, v in self.edges_as_tuples
+                ]
+            ],
+            dtype=torch.float32,
+        ).t()
+
         if mode_map_override:
             mode_map = mode_map_override
         else:
@@ -306,108 +491,219 @@ def match_fragments_to_peaks(self, mz_fragments, int_list=None, mode_map_overrid
         # Flatten out all edges from fragments
         self.edge_intensities = []
-        for mz in self.peak_matches.keys():
-            intensity = self.peak_matches[mz]["intensity"] / sum(f.num_of_edges() for f in self.peak_matches[mz]["fragments"])
-            self.peak_matches[mz]["edges"] = [e for f in self.peak_matches[mz]["fragments"] for e in f.edges]
+        for mz in self.peak_matches.keys():
+            intensity = self.peak_matches[mz]["intensity"] / sum(
+                f.num_of_edges() for f in self.peak_matches[mz]["fragments"]
+            )
+            self.peak_matches[mz]["edges"] = [
+                e for f in self.peak_matches[mz]["fragments"] for e in f.edges
+            ]
             for i, f in enumerate(self.peak_matches[mz]["fragments"]):
                 for j, edge in enumerate(f.edges):
-                    entry = (edge, {'intensity': intensity, 'fragment': f, 'break_side': f.break_sides[j], 'ion_mode': self.peak_matches[mz]["ion_modes"][i][0]})
+                    entry = (
+                        edge,
+                        {
+                            "intensity": intensity,
+                            "fragment": f,
+                            "break_side": f.break_sides[j],
+                            "ion_mode": self.peak_matches[mz]["ion_modes"][i][0],
+                        },
+                    )
                     self.edge_intensities.append(entry)
 
+        self.edge_break_count = torch.zeros(
+            size=edge_break_labels.size(), dtype=torch.float32
+        )
+        self.precursor_count, self.precursor_prob, self.precursor_sqrt_prob = (
+            torch.tensor(0.0),
+            torch.tensor(0.0),
+            torch.tensor(0.0),
+        )
+
+        self.edge_count_matrix = torch.zeros(
+            size=(edge_break_labels.shape[0], 2 * len(mode_map)), dtype=torch.float32
+        )
-        self.edge_break_count = torch.zeros(size = edge_break_labels.size(), dtype=torch.float32)
-        self.precursor_count, self.precursor_prob, self.precursor_sqrt_prob = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
-
-
-        self.edge_count_matrix = torch.zeros(size = (edge_break_labels.shape[0], 2*len(mode_map)), dtype=torch.float32)
-
         # Determining edge break probabilities from peak intensities. Multiple edges for the same fragment -> divide by number of edges. Multiple fragments from edge -> add intensities.
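+        # Worked example (hypothetical numbers): a peak of intensity 60 matched by
+        # one fragment reachable over two edges adds 60/2 = 30 to each edge's count;
+        # two matched fragments sharing one edge add their shares onto that edge.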
        for edge, values in self.edge_intensities:
-            if edge == None: # precursor
-                self.precursor_count += values['intensity']
+            if edge is None:  # precursor
+                self.precursor_count += values["intensity"]
                 continue
-            edge_index = torch.logical_or(self.edges == torch.tensor(edge), self.edges == torch.tensor(edge[::-1])).all(dim=1).nonzero().squeeze()
-            self.edge_break_count[edge_index] += values['intensity']
-
-            forward_idx = ((torch.tensor(edge) == self.edges).sum(dim=1) == 2).nonzero().squeeze()
-            backward_idx = ((torch.tensor(edge[::-1]) == self.edges).sum(dim=1) == 2).nonzero().squeeze()
-
-
-            col = mode_map[values["ion_mode"]] if values["break_side"]=="left" else mode_map[values["ion_mode"]] + len(mode_map)
-            self.edge_count_matrix[forward_idx, col] = values['intensity']
-            col = (col + len(mode_map)) % (2*len(mode_map)) #to the other side of the break
-            self.edge_count_matrix[backward_idx, col] = values['intensity']
-
-            #"bond_features_one_hot",
-        # Compile probability vectors
+            edge_index = (
+                torch.logical_or(
+                    self.edges == torch.tensor(edge),
+                    self.edges == torch.tensor(edge[::-1]),
+                )
+                .all(dim=1)
+                .nonzero()
+                .squeeze()
+            )
+            self.edge_break_count[edge_index] += values["intensity"]
+
+            forward_idx = (
+                ((torch.tensor(edge) == self.edges).sum(dim=1) == 2).nonzero().squeeze()
+            )
+            backward_idx = (
+                ((torch.tensor(edge[::-1]) == self.edges).sum(dim=1) == 2)
+                .nonzero()
+                .squeeze()
+            )
+
+            col = (
+                mode_map[values["ion_mode"]]
+                if values["break_side"] == "left"
+                else mode_map[values["ion_mode"]] + len(mode_map)
+            )
+            self.edge_count_matrix[forward_idx, col] = values["intensity"]
+            col = (col + len(mode_map)) % (
+                2 * len(mode_map)
+            )  # to the other side of the break
+            self.edge_count_matrix[backward_idx, col] = values["intensity"]
+
+            # "bond_features_one_hot",
+        # Compile probability vectors
         # self.compiled_counts = torch.cat([self.edge_break_count.flatten(), self.precursor_count.unsqueeze(dim=-1), self.precursor_count.unsqueeze(dim=-1)])
         # self.compiled_probs = 2 * self.compiled_counts / torch.sum(self.compiled_counts)
-
+
         # COMPILED VECTORS COUNTS & PROBABILITIES FOR END-TO-END PREDICTION!
Default is compiled_probsALL - self.compiled_countsALL = torch.cat([self.edge_count_matrix.flatten(), self.precursor_count.unsqueeze(dim=-1), self.precursor_count.unsqueeze(dim=-1)]) - self.compiled_probsALL = 2 * self.compiled_countsALL / torch.sum(self.compiled_countsALL) - + self.compiled_countsALL = torch.cat( + [ + self.edge_count_matrix.flatten(), + self.precursor_count.unsqueeze(dim=-1), + self.precursor_count.unsqueeze(dim=-1), + ] + ) + self.compiled_probsALL = ( + 2 * self.compiled_countsALL / torch.sum(self.compiled_countsALL) + ) + # SQRT transformation self.compiled_countsSQRT = torch.sqrt(self.compiled_countsALL) - self.compiled_probsSQRT = 2 * self.compiled_countsSQRT / torch.sum(self.compiled_countsSQRT) - + self.compiled_probsSQRT = ( + 2 * self.compiled_countsSQRT / torch.sum(self.compiled_countsSQRT) + ) - # MASKS # self.compiled_validation_mask = torch.cat([self.is_edge_not_in_ring.bool().squeeze(), torch.tensor([True, True], dtype=bool)], dim=-1) - self.compiled_validation_maskALL = torch.cat([torch.repeat_interleave(self.is_edge_not_in_ring.bool().squeeze(), len(mode_map)*2), torch.tensor([True, True], dtype=bool)], dim=-1) + self.compiled_validation_maskALL = torch.cat( + [ + torch.repeat_interleave( + self.is_edge_not_in_ring.bool().squeeze(), len(mode_map) * 2 + ), + torch.tensor([True, True], dtype=bool), + ], + dim=-1, + ) # self.compiled_forward_mask = torch.cat([self.edge_forward_direction.squeeze(), torch.tensor([True, False], dtype=bool)], dim=-1) - + # Track additional statistics max_intensity = max(int_list) intensity_filter_threshold = 0.01 self.match_stats = { - 'counts': self.compiled_countsALL.sum().tolist() / 2.0, # self.compiled_counts.sum().tolist() / 2.0, - 'ms_all_counts': sum(int_list), - 'coverage': (self.compiled_countsALL.sum().tolist() / 2.0) / sum(int_list), - 'coverage_wo_prec': (self.edge_break_count.sum().tolist() / 2.0) / (sum(int_list) - self.precursor_count.tolist()), - 'precursor_prob': self.precursor_count.tolist() / (self.compiled_countsALL.sum().tolist() / 2.0) if (self.compiled_countsALL.sum().tolist() / 2.0) > 0 else 0.0, - 'precursor_raw_prob': self.precursor_count.tolist() / sum(int_list), - 'num_peaks': len(mz_fragments), - 'num_peak_matches': len(self.peak_matches), - 'percent_peak_matches': len(self.peak_matches) / len(mz_fragments), - 'num_peaks_filtered': sum([(i / max_intensity) > intensity_filter_threshold for i in int_list]), - 'num_peak_matches_filtered': sum([match["relative_intensity"] > intensity_filter_threshold for mz, match in self.peak_matches.items()]), - 'percent_peak_matches_filtered': sum([match["relative_intensity"] > intensity_filter_threshold for mz, match in self.peak_matches.items()]) / len(mz_fragments), - 'num_non_precursor_matches': sum([(None not in match["edges"]) for mz, match in self.peak_matches.items()]), - 'num_peak_match_conflicts': sum([len(match["edges"]) > 1 for mz, match in self.peak_matches.items()]), - 'num_fragment_conflicts': sum([len(match["fragments"]) > 1 for mz, match in self.peak_matches.items()]), - 'rel_fragment_conflicts': sum([len(match["fragments"]) > 1 for mz, match in self.peak_matches.items()]) / sum([(None not in match["edges"]) for mz, match in self.peak_matches.items()]) if sum([(None not in match["edges"]) for mz, match in self.peak_matches.items()]) > 0 else 0, - 'ms_num_all_peaks': len(mz_fragments) + "counts": self.compiled_countsALL.sum().tolist() + / 2.0, # self.compiled_counts.sum().tolist() / 2.0, + "ms_all_counts": sum(int_list), + "coverage": 
(self.compiled_countsALL.sum().tolist() / 2.0) / sum(int_list), + "coverage_wo_prec": (self.edge_break_count.sum().tolist() / 2.0) + / (sum(int_list) - self.precursor_count.tolist()), + "precursor_prob": self.precursor_count.tolist() + / (self.compiled_countsALL.sum().tolist() / 2.0) + if (self.compiled_countsALL.sum().tolist() / 2.0) > 0 + else 0.0, + "precursor_raw_prob": self.precursor_count.tolist() / sum(int_list), + "num_peaks": len(mz_fragments), + "num_peak_matches": len(self.peak_matches), + "percent_peak_matches": len(self.peak_matches) / len(mz_fragments), + "num_peaks_filtered": sum( + [(i / max_intensity) > intensity_filter_threshold for i in int_list] + ), + "num_peak_matches_filtered": sum( + [ + match["relative_intensity"] > intensity_filter_threshold + for mz, match in self.peak_matches.items() + ] + ), + "percent_peak_matches_filtered": sum( + [ + match["relative_intensity"] > intensity_filter_threshold + for mz, match in self.peak_matches.items() + ] + ) + / len(mz_fragments), + "num_non_precursor_matches": sum( + [ + (None not in match["edges"]) + for mz, match in self.peak_matches.items() + ] + ), + "num_peak_match_conflicts": sum( + [len(match["edges"]) > 1 for mz, match in self.peak_matches.items()] + ), + "num_fragment_conflicts": sum( + [len(match["fragments"]) > 1 for mz, match in self.peak_matches.items()] + ), + "rel_fragment_conflicts": sum( + [len(match["fragments"]) > 1 for mz, match in self.peak_matches.items()] + ) + / sum( + [ + (None not in match["edges"]) + for mz, match in self.peak_matches.items() + ] + ) + if sum( + [ + (None not in match["edges"]) + for mz, match in self.peak_matches.items() + ] + ) + > 0 + else 0, + "ms_num_all_peaks": len(mz_fragments), } if match_stats_only: self.free_memory() def get_memory_usage(self): - memory_usage = {attr: sys.getsizeof(value) for attr, value in self.__dict__.items()} + memory_usage = { + attr: sys.getsizeof(value) for attr, value in self.__dict__.items() + } total_size = sum(memory_usage.values()) - return {"attributes": dict(sorted(memory_usage.items(), key=lambda x: x[1], reverse=True)), "total_size": total_size} + return { + "attributes": dict( + sorted(memory_usage.items(), key=lambda x: x[1], reverse=True) + ), + "total_size": total_size, + } def free_memory(self): attributes_to_free = [ - "edge_break_count", "precursor_count", - "precursor_prob", "precursor_sqrt_prob", "edge_count_matrix", - "compiled_countsALL", "compiled_probsALL", "compiled_countsSQRT", - "compiled_probsSQRT", "compiled_validation_maskALL", - "edge_breaks", "edge_intensities", "setup_features", - "setup_features_per_edge", "node_features", "node_features_one_hot", - "bond_features", "bond_features_one_hot" - ] # Tensors from peak matching - + "edge_break_count", + "precursor_count", + "precursor_prob", + "precursor_sqrt_prob", + "edge_count_matrix", + "compiled_countsALL", + "compiled_probsALL", + "compiled_countsSQRT", + "compiled_probsSQRT", + "compiled_validation_maskALL", + "edge_breaks", + "edge_intensities", + "setup_features", + "setup_features_per_edge", + "node_features", + "node_features_one_hot", + "bond_features", + "bond_features_one_hot", + ] # Tensors from peak matching for attr in attributes_to_free: if hasattr(self, attr): delattr(self, attr) - - def as_geometric_data(self, with_labels=True): if with_labels: return Data( @@ -415,66 +711,59 @@ def as_geometric_data(self, with_labels=True): edge_index=self.edges.t().contiguous(), edge_type=self.edge_bond_types, edge_attr=self.bond_features, - edge_elem_comp = 
self.subgraph_elem_comp, + edge_elem_comp=self.subgraph_elem_comp, subgraph_idx_left=self.subgraph_idx_left, subgraph_idx_right=self.subgraph_idx_right, static_graph_features=self.setup_features, static_edge_features=self.setup_features_per_edge, - static_rt_features = self.rt_setup_features, - - + static_rt_features=self.rt_setup_features, # labels - #y=self.edge_break_labels, + # y=self.edge_break_labels, compiled_probsALL=self.compiled_probsALL, compiled_probsSQRT=self.compiled_probsSQRT, # compiled_counts=self.compiled_counts, edge_break_count=self.edge_break_count, - #edge_break_prob=self.edge_break_prob, - #edge_break_prob_wo_precursor=self.edge_break_prob_wo_precursor, - #edge_break_sqrt_prob=self.edge_break_sqrt_prob, - #precursor_prob = self.precursor_prob, - retention_time = self.rt, - retention_mask = self.rt_mask, - ccs = self.ccs, - ccs_mask = self.ccs_mask, - + # edge_break_prob=self.edge_break_prob, + # edge_break_prob_wo_precursor=self.edge_break_prob_wo_precursor, + # edge_break_sqrt_prob=self.edge_break_sqrt_prob, + # precursor_prob = self.precursor_prob, + retention_time=self.rt, + retention_mask=self.rt_mask, + ccs=self.ccs, + ccs_mask=self.ccs_mask, # masks and groups validation_mask=self.is_edge_not_in_ring.bool(), # compiled_validation_mask = self.compiled_validation_mask, - compiled_validation_maskALL = self.compiled_validation_maskALL, - + compiled_validation_maskALL=self.compiled_validation_maskALL, # group identity and loss weights group_id=self.id, - weight = torch.tensor([self.loss_weight]).unsqueeze(dim=-1), - weight_tensor=torch.full(self.compiled_probsALL.shape, self.loss_weight), - + weight=torch.tensor([self.loss_weight]).unsqueeze(dim=-1), + weight_tensor=torch.full( + self.compiled_probsALL.shape, self.loss_weight + ), # Stepped collision energies - ce_steps = self.ce_steps, - ce_idx = self.ce_idx, # geom treats values with suffix _index differently -> avoid - - + ce_steps=self.ce_steps, + ce_idx=self.ce_idx, # geom treats values with suffix _index differently -> avoid # additional information is_node_aromatic=self.is_node_aromatic, - is_edge_aromatic=self.is_edge_aromatic - ) + is_edge_aromatic=self.is_edge_aromatic, + ) else: return Data( x=self.node_features, edge_index=self.edges.t().contiguous(), edge_type=self.edge_bond_types, edge_attr=self.bond_features, - edge_elem_comp = self.subgraph_elem_comp, + edge_elem_comp=self.subgraph_elem_comp, subgraph_idx_left=self.subgraph_idx_left, subgraph_idx_right=self.subgraph_idx_right, static_graph_features=self.setup_features, static_edge_features=self.setup_features_per_edge, - static_rt_features = self.rt_setup_features, - + static_rt_features=self.rt_setup_features, # masks and groups validation_mask=self.is_edge_not_in_ring.bool(), group_id=self.id, - # additional information is_node_aromatic=self.is_node_aromatic, - is_edge_aromatic=self.is_edge_aromatic - ) \ No newline at end of file + is_edge_aromatic=self.is_edge_aromatic, + ) diff --git a/fiora/MOL/MetaboliteDatasetStatistics.py b/fiora/MOL/MetaboliteDatasetStatistics.py index e130791..cb0f099 100644 --- a/fiora/MOL/MetaboliteDatasetStatistics.py +++ b/fiora/MOL/MetaboliteDatasetStatistics.py @@ -3,6 +3,7 @@ from fiora.MOL.Metabolite import Metabolite from fiora.MOL.constants import ORDERED_ELEMENT_LIST + class MetaboliteDatasetStatistics: def __init__(self, data: pd.DataFrame): """ @@ -20,29 +21,37 @@ def _compute_element_composition_stats(self): """ all_meta = [] for _, row in self.data.iterrows(): - metabolite = row['Metabolite'] + metabolite = 
row["Metabolite"] if isinstance(metabolite, Metabolite): # Create nested dictionaries for presence and count metadata_dict = { - "element_presence": {element: int(element in metabolite.node_elements) for element in ORDERED_ELEMENT_LIST}, - "element_count": {element: metabolite.node_elements.count(element) for element in ORDERED_ELEMENT_LIST}, + "element_presence": { + element: int(element in metabolite.node_elements) + for element in ORDERED_ELEMENT_LIST + }, + "element_count": { + element: metabolite.node_elements.count(element) + for element in ORDERED_ELEMENT_LIST + }, } # Add additional metadata metadata_dict["ExactMolWeight"] = metabolite.ExactMolWeight metadata_dict["Formula"] = metabolite.Formula metadata_dict["SMILES"] = metabolite.SMILES metadata_dict["InChIKey"] = metabolite.InChIKey - metadata_dict["TotalElements"] = len(metabolite.node_elements) # Total number of elements + metadata_dict["TotalElements"] = len( + metabolite.node_elements + ) # Total number of elements all_meta.append(metadata_dict) return pd.DataFrame(all_meta) - + def _compute_element_summary(self): """ Compute total counts, presence probability for each element, and ANY_RARE probability across the entire dataset. :return: dict with total counts, presence probabilities for each element, and ANY_RARE probability. """ - individual_stats = self.statistics['Individual_molecular_stats'] + individual_stats = self.statistics["Individual_molecular_stats"] # Initialize counters total_counts = Counter() @@ -50,12 +59,16 @@ def _compute_element_summary(self): any_rare_count = 0 # Counter for molecules with at least one rare element # Define rare elements (everything except C, O, N, H) - rare_elements = [element for element in ORDERED_ELEMENT_LIST if element not in ["C", "O", "N", "H"]] + rare_elements = [ + element + for element in ORDERED_ELEMENT_LIST + if element not in ["C", "O", "N", "H"] + ] # Aggregate counts and presence probabilities for _, row in individual_stats.iterrows(): - element_counts = row['element_count'] - element_presence = row['element_presence'] + element_counts = row["element_count"] + element_presence = row["element_presence"] total_counts.update(element_counts) presence_counts.update(element_presence) @@ -65,7 +78,10 @@ def _compute_element_summary(self): # Compute presence probabilities total_molecules = len(individual_stats) - presence_probabilities = {element: presence_counts[element] / total_molecules for element in ORDERED_ELEMENT_LIST} + presence_probabilities = { + element: presence_counts[element] / total_molecules + for element in ORDERED_ELEMENT_LIST + } # Compute ANY_RARE probability and add it as another "element" presence_probabilities["ANY_RARE"] = any_rare_count / total_molecules @@ -75,27 +91,25 @@ def _compute_element_summary(self): "Presence Probabilities": presence_probabilities, } - def generate_molecular_statistics(self, unique_compounds: bool = True): """ Precompute molecular statistics using the Metabolite class and store them in the class. 
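+
+        Example (hypothetical DataFrame df with a 'Metabolite' column):
+
+            stats = MetaboliteDatasetStatistics(df)
+            stats.generate_molecular_statistics()
+            summary = stats.get_statistics()["Molecular Summary"]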
""" # Retrieve detailed information for each metabolite if unique_compounds: - self.data = self.data.drop_duplicates(subset='group_id') - self.statistics['Individual_molecular_stats'] = self._compute_element_composition_stats() - self.statistics['Molecular Summary'] = self._compute_element_summary() - - - + self.data = self.data.drop_duplicates(subset="group_id") + self.statistics["Individual_molecular_stats"] = ( + self._compute_element_composition_stats() + ) + self.statistics["Molecular Summary"] = self._compute_element_summary() def _compute_duplicates(self): """ Compute duplicate occurrences based on 'group_id'. :return: pd.DataFrame with group_id counts. """ - group_counts = self.data['group_id'].value_counts().reset_index() - group_counts.columns = ['group_id', 'Count'] + group_counts = self.data["group_id"].value_counts().reset_index() + group_counts.columns = ["group_id", "Count"] return group_counts def get_statistics(self): @@ -104,5 +118,7 @@ def get_statistics(self): :return: dict containing all statistics. """ if not self.statistics: - raise ValueError("Statistics have not been generated yet. Call generate_molecular_statistics() first.") - return self.statistics \ No newline at end of file + raise ValueError( + "Statistics have not been generated yet. Call generate_molecular_statistics() first." + ) + return self.statistics diff --git a/fiora/MOL/MetaboliteIndex.py b/fiora/MOL/MetaboliteIndex.py index cc50c74..42830f3 100644 --- a/fiora/MOL/MetaboliteIndex.py +++ b/fiora/MOL/MetaboliteIndex.py @@ -2,14 +2,13 @@ from fiora.MOL.Metabolite import Metabolite from fiora.MOL.FragmentationTree import FragmentationTree -class MetaboliteIndex: +class MetaboliteIndex: def __init__(self) -> None: self.metabolite_index = {} def index_metabolites(self, list_of_metabolites: List) -> None: for metabolite in list_of_metabolites: - id = self.find_metabolite_id(metabolite) if id is not None: metabolite.set_id(id) @@ -18,32 +17,47 @@ def index_metabolites(self, list_of_metabolites: List) -> None: self.metabolite_index[new_id] = {"Metabolite": metabolite} metabolite.set_id(new_id) - def create_fragmentation_trees(self, depth: int=1) -> None: + def create_fragmentation_trees(self, depth: int = 1) -> None: for id, entry in self.metabolite_index.items(): metabolite = entry["Metabolite"] - entry["FragmentationTree"] = FragmentationTree(metabolite.MOL) - entry["FragmentationTree"].build_fragmentation_tree(metabolite.MOL, metabolite.edges_as_tuples, depth=depth) + entry["FragmentationTree"] = FragmentationTree(metabolite.MOL) + entry["FragmentationTree"].build_fragmentation_tree( + metabolite.MOL, metabolite.edges_as_tuples, depth=depth + ) - def add_fragmentation_trees_to_metabolite_list(self, list_of_metabolites: List[Metabolite], graph_mismatch_policy: Literal["ignore", "recompute"]="recompute") -> None: + def add_fragmentation_trees_to_metabolite_list( + self, + list_of_metabolites: List[Metabolite], + graph_mismatch_policy: Literal["ignore", "recompute"] = "recompute", + ) -> None: list_of_mismatched_ids = [] - + for metabolite in list_of_metabolites: id = metabolite.get_id() if id is not None: # Check if metabolite edges align with the index - if metabolite.edges_as_tuples == self.metabolite_index[id]["Metabolite"].edges_as_tuples: - metabolite.add_fragmentation_tree(self.metabolite_index[id]["FragmentationTree"]) + if ( + metabolite.edges_as_tuples + == self.metabolite_index[id]["Metabolite"].edges_as_tuples + ): + metabolite.add_fragmentation_tree( + 
self.metabolite_index[id]["FragmentationTree"] + ) else: if graph_mismatch_policy == "recompute": metabolite.fragment_MOL() elif graph_mismatch_policy == "ignore": - metabolite.add_fragmentation_tree(self.metabolite_index[id]["FragmentationTree"]) + metabolite.add_fragmentation_tree( + self.metabolite_index[id]["FragmentationTree"] + ) else: - raise ValueError("Invalid graph_mismatch_policy. Use 'ignore' or 'recompute'.") + raise ValueError( + "Invalid graph_mismatch_policy. Use 'ignore' or 'recompute'." + ) list_of_mismatched_ids.append((metabolite, id)) return list_of_mismatched_ids - + def find_metabolite_id(self, metabolite: Metabolite) -> int: for id, entry in self.metabolite_index.items(): if metabolite == entry["Metabolite"]: @@ -52,9 +66,9 @@ def find_metabolite_id(self, metabolite: Metabolite) -> int: def get_metabolite(self, id: int) -> Metabolite: return self.metabolite_index[id] - + def get_fragmentation_tree(self, id: int) -> FragmentationTree: return self.metabolite_index[id]["FragmentationTree"] - + def get_number_of_metabolites(self) -> int: - return len(self.metabolite_index) \ No newline at end of file + return len(self.metabolite_index) diff --git a/fiora/MOL/collision_energy.py b/fiora/MOL/collision_energy.py index 2938550..95cc739 100644 --- a/fiora/MOL/collision_energy.py +++ b/fiora/MOL/collision_energy.py @@ -1,13 +1,21 @@ charge_factor = {1: 1, 2: 0.9, 3: 0.85, 4: 0.8, 5: 0.75} -nce_instruments = ["Orbitrap", "LC-ESI-QFT", "LC-APCI-ITFT", "Linear Ion Trap", "LC-ESI-ITFT"] # "Flow-injection QqQ/MS", +nce_instruments = [ + "Orbitrap", + "LC-ESI-QFT", + "LC-APCI-ITFT", + "Linear Ion Trap", + "LC-ESI-ITFT", +] # "Flow-injection QqQ/MS", + def NCE_to_eV(nce, precursor_mz, charge=1): return nce * precursor_mz / 500 * charge_factor[charge] + def align_CE(ce, precursor_mz, instrument=None): - if type(ce) == float: + if isinstance(ce, float): if ce > 0.0 and ce < 1.0: - return str(ce) # REMOVE + return str(ce) # REMOVE if instrument in nce_instruments: return NCE_to_eV(ce, precursor_mz) return ce @@ -18,62 +26,62 @@ def align_CE(ce, precursor_mz, instrument=None): ce = ce.replace("eV", "") try: return float(ce) - except: + except Exception as _: return ce elif "V" in ce: ce = ce.replace("V", "") try: return float(ce) - except: + except Exception as _: return ce elif "ev" in ce: ce = ce.replace("ev", "") try: return float(ce) - except: + except Exception as _: return ce elif "% (nominal)" in ce: try: - nce = ce.split('% (nominal)')[0].strip().split(' ')[-1] + nce = ce.split("% (nominal)")[0].strip().split(" ")[-1] nce = float(nce) return NCE_to_eV(nce, precursor_mz) - except: + except Exception as _: return ce elif "(nominal)" in ce: try: - nce = ce.split('(nominal)')[0].strip().split(' ')[-1] + nce = ce.split("(nominal)")[0].strip().split(" ")[-1] nce = float(nce) return NCE_to_eV(nce, precursor_mz) - except: + except Exception as _: return ce elif "(NCE)" in ce: try: - nce = ce.strip().split('(NCE)')[0] + nce = ce.strip().split("(NCE)")[0] nce = float(nce) return NCE_to_eV(nce, precursor_mz) - except: + except Exception as _: return ce elif "HCD" in ce: try: - nce = ce.strip().split('HCD')[0] + nce = ce.strip().split("HCD")[0] nce = float(nce) return NCE_to_eV(nce, precursor_mz) - except: + except Exception as _: return ce elif "%" in ce: try: - nce = ce.split('%')[0].strip().split(' ')[-1] + nce = ce.split("%")[0].strip().split(" ")[-1] nce = float(nce) return NCE_to_eV(nce, precursor_mz) - except: + except Exception as _: return ce else: - try: + try: ce = float(ce) if 
ce > 0.0 and ce < 1.0:
-                return str(ce) # REMOVE
+                return str(ce)  # REMOVE
             if instrument in nce_instruments:
                 return NCE_to_eV(ce, precursor_mz)
             return ce
-        except:
-            return ce
\ No newline at end of file
+        except Exception as _:
+            return ce
diff --git a/fiora/MOL/constants.py b/fiora/MOL/constants.py
index 3b6c06d..49c1cfe 100644
--- a/fiora/MOL/constants.py
+++ b/fiora/MOL/constants.py
@@ -1,64 +1,97 @@
 from rdkit import Chem
 from rdkit.Chem import Descriptors
 
-h_minus = Chem.MolFromSmiles("[H-]") #hydrid
-h_plus = Chem.MolFromSmiles("[H+]") #h proton
-h_2 = Chem.MolFromSmiles("[HH]") #h2
+h_minus = Chem.MolFromSmiles("[H-]")  # hydride
+h_plus = Chem.MolFromSmiles("[H+]")  # h proton
+h_2 = Chem.MolFromSmiles("[HH]")  # h2
 
 ADDUCT_WEIGHTS = {
-    "[M+H]+": Descriptors.ExactMolWt(h_plus), #1.007276,
-    "[M+H]-": Descriptors.ExactMolWt(h_plus), # TODO might not technically exist
+    "[M+H]+": Descriptors.ExactMolWt(h_plus),  # 1.007276,
+    "[M+H]-": Descriptors.ExactMolWt(h_plus),  # TODO might not technically exist
     "[M+NH4]+": 18.033823,
-    "[M+Na]+": 22.989218 ,
-    "[M-H]-": -1*Descriptors.ExactMolWt(h_plus),
-    
+    "[M+Na]+": 22.989218,
+    "[M-H]-": -1 * Descriptors.ExactMolWt(h_plus),
     #
     # positvie fragment rearrangements
     #
-    "[M-H]+": -1*Descriptors.ExactMolWt(h_minus), # Double bond replacing 2 hydrogen atoms + H
+    "[M-H]+": -1
+    * Descriptors.ExactMolWt(h_minus),  # Double bond replacing 2 hydrogen atoms + H
     "[M]+": 0,
-    "[M-2H]+": -1 * Descriptors.ExactMolWt(h_2), # Loosing proton and hydrid
-    "[M-3H]+": -1 * Descriptors.ExactMolWt(h_2) - 1 * Descriptors.ExactMolWt(h_minus), # 2 Double bonds + H
+    "[M-2H]+": -1 * Descriptors.ExactMolWt(h_2),  # Losing a proton and a hydride
+    "[M-3H]+": -1 * Descriptors.ExactMolWt(h_2)
+    - 1 * Descriptors.ExactMolWt(h_minus),  # 2 Double bonds + H
     # experimental cases
-    #"[M-4H]+": -1.007276 * 4,
-    #"[M-5H]+": -1.007276 * 5,
-    
-    
+    # "[M-4H]+": -1.007276 * 4,
+    # "[M-5H]+": -1.007276 * 5,
     #
     # negative fragment rearrangements
-    # 
-    
+    #
     # "[M-H]-": -1*Chem.Descriptors.ExactMolWt(h_plus), # see above
-    "[M]-": 0, # could be one electron too many
+    "[M]-": 0,  # could be one electron too many
     "[M-2H]-": -1 * Descriptors.ExactMolWt(h_2),
-    "[M-3H]-": -1 * Descriptors.ExactMolWt(h_2) - 1 * Chem.Descriptors.ExactMolWt(h_plus),
-    
+    "[M-3H]-": -1 * Descriptors.ExactMolWt(h_2)
+    - 1 * Chem.Descriptors.ExactMolWt(h_plus),
     #
     # Hydrogen gains
     #
-    
-    "[M+2H]+": Descriptors.ExactMolWt(h_plus) + 1 * Descriptors.ExactMolWt(Chem.MolFromSmiles("[H]")), # 1 proton + 1 neutral hydrogens
-    "[M+3H]+": Descriptors.ExactMolWt(h_plus) + 2 * Descriptors.ExactMolWt(Chem.MolFromSmiles("[H]")), # 1 proton + 2 neutral hydrogens
-    "[M+2H]-": Descriptors.ExactMolWt(h_plus) + 1 * Descriptors.ExactMolWt(Chem.MolFromSmiles("[H]")), # 1 proton + 2 neutral hydrogens
-    "[M+3H]-": Descriptors.ExactMolWt(h_plus) + 2 * Descriptors.ExactMolWt(Chem.MolFromSmiles("[H]")), # 1 proton + 2 neutral hydrogens
-    
-    }
-    
+    "[M+2H]+": Descriptors.ExactMolWt(h_plus)
+    + 1
+    * Descriptors.ExactMolWt(
+        Chem.MolFromSmiles("[H]")
+    ),  # 1 proton + 1 neutral hydrogen
+    "[M+3H]+": Descriptors.ExactMolWt(h_plus)
+    + 2
+    * Descriptors.ExactMolWt(
+        Chem.MolFromSmiles("[H]")
+    ),  # 1 proton + 2 neutral hydrogens
+    "[M+2H]-": Descriptors.ExactMolWt(h_plus)
+    + 1
+    * Descriptors.ExactMolWt(
+        Chem.MolFromSmiles("[H]")
+    ),  # 1 proton + 1 neutral hydrogen
+    "[M+3H]-": Descriptors.ExactMolWt(h_plus)
+    + 2
+    * Descriptors.ExactMolWt(
+        Chem.MolFromSmiles("[H]")
+    ),  # 1 proton + 2 neutral hydrogens
+}
 
-PPM = 1/1000000 
+PPM = 1 / 1000000 
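# Illustrative sketch of how these tolerance constants combine; the helper
# name `abs_tolerance` is hypothetical:
#
#   def abs_tolerance(mz, ppm=100 * PPM, floor=0.01):
#       return max(ppm * mz, floor)  # relative tolerance with an absolute floor
#
#   abs_tolerance(500.0)  # -> 0.05 Da, the DEFAULT_DALTON equivalent below
#   abs_tolerance(50.0)   # -> 0.01 Da, the MIN_ABS_TOLERANCE floor takes over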
DEFAULT_PPM = 100 * PPM -DEFAULT_DALTON = 0.05 # equiv: 100ppm of 500 m/z -MIN_ABS_TOLERANCE = 0.01 # 0.02 # Tolerance aplied for small fragment when relative PPM gets too small -#DEFAULT_MODES = ["[M+H]+", "[M-H]+", "[M-3H]+"] -DEFAULT_MODES = ["[M+H]+", "[M]+", "[M-H]+", "[M-2H]+", "[M-3H]+"] #"[M-4H]+"] #, "[M-5H]+"] +DEFAULT_DALTON = 0.05 # equiv: 100ppm of 500 m/z +MIN_ABS_TOLERANCE = ( + 0.01 # 0.02 # Tolerance aplied for small fragment when relative PPM gets too small +) +# DEFAULT_MODES = ["[M+H]+", "[M-H]+", "[M-3H]+"] +DEFAULT_MODES = [ + "[M+H]+", + "[M]+", + "[M-H]+", + "[M-2H]+", + "[M-3H]+", +] # "[M-4H]+"] #, "[M-5H]+"] DEFAULT_MODE_MAP = {mode: i for i, mode in enumerate(DEFAULT_MODES)} -#NEGATIVE_MODES = ["[M]-", "[M-H]-", "[M-2H]-", "[M-3H]-", "[M-4H]-"] +# NEGATIVE_MODES = ["[M]-", "[M-H]-", "[M-2H]-", "[M-3H]-", "[M-4H]-"] # source: https://fiehnlab.ucdavis.edu/staff/kind/metabolomics/ms-adduct-calculator -ORDERED_ELEMENT_LIST = ['Br', 'C', 'Cl', 'F', 'I', 'N', 'O', 'P', 'S'] # Warning: Changes may affect model and version control -ORDERED_ELEMENT_LIST_WITH_HYDROGEN = ORDERED_ELEMENT_LIST + ['H'] # Hydrogen is added at the end for element composition encoding +ORDERED_ELEMENT_LIST = [ + "Br", + "C", + "Cl", + "F", + "I", + "N", + "O", + "P", + "S", +] # Warning: Changes may affect model and version control +ORDERED_ELEMENT_LIST_WITH_HYDROGEN = ORDERED_ELEMENT_LIST + [ + "H" +] # Hydrogen is added at the end for element composition encoding -MAX_SUBGRAPH_NODES = 128 # Maximum number of nodes in a subgraph, used for Batch padding and indexing \ No newline at end of file +MAX_SUBGRAPH_NODES = ( + 128 # Maximum number of nodes in a subgraph, used for Batch padding and indexing +) diff --git a/fiora/MOL/mol_graph.py b/fiora/MOL/mol_graph.py index 40226fe..cd61160 100644 --- a/fiora/MOL/mol_graph.py +++ b/fiora/MOL/mol_graph.py @@ -1,94 +1,102 @@ - import numpy as np -import matplotlib.pyplot as plt +import matplotlib.pyplot as plt import torch import networkx as nx -node_color_map = {'C': 'gray', - 'O': 'red', - 'N': 'blue'} - +node_color_map = {"C": "gray", "O": "red", "N": "blue"} -edge_color_map = {'SINGLE': 'black', - 'DOUBLE': 'black', - 'AROMATIC': 'blue'} - -edge_width_map = {'SINGLE': 1.5, - 'DOUBLE': 3, - 'AROMATIC': 3} +edge_color_map = {"SINGLE": "black", "DOUBLE": "black", "AROMATIC": "blue"} +edge_width_map = {"SINGLE": 1.5, "DOUBLE": 3, "AROMATIC": 3} def mol_to_graph(mol): G = nx.Graph() for atom in mol.GetAtoms(): - color = node_color_map[atom.GetSymbol()] if atom.GetSymbol() in node_color_map.keys() else 'black' - G.add_node(atom.GetIdx(), - atomic_num=atom.GetAtomicNum(), - is_aromatic=atom.GetIsAromatic(), - atom_symbol=atom.GetSymbol(), - color=color, - atom=atom) + color = ( + node_color_map[atom.GetSymbol()] + if atom.GetSymbol() in node_color_map.keys() + else "black" + ) + G.add_node( + atom.GetIdx(), + atomic_num=atom.GetAtomicNum(), + is_aromatic=atom.GetIsAromatic(), + atom_symbol=atom.GetSymbol(), + color=color, + atom=atom, + ) for bond in mol.GetBonds(): - G.add_edge(bond.GetBeginAtomIdx(), - bond.GetEndAtomIdx(), - bond_type=bond.GetBondType(), - bond=bond) + G.add_edge( + bond.GetBeginAtomIdx(), + bond.GetEndAtomIdx(), + bond_type=bond.GetBondType(), + bond=bond, + ) return G + def draw_graph(G, ax=None, edge_labels=False): if not ax: ax = plt.gca() pos = nx.spring_layout(G) - nx.draw(G,ax=ax, pos=pos, - labels=nx.get_node_attributes(G, 'atom_symbol'), - with_labels = True, - node_color=list(nx.get_node_attributes(G, 'color').values()), + 
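    # Illustrative usage sketch for the helpers above, with ethanol as an
    # arbitrary example molecule:
    #
    #   from rdkit import Chem
    #   G = mol_to_graph(Chem.MolFromSmiles("CCO"))
    #   draw_graph(G, edge_labels=True)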
nx.draw( + G, + ax=ax, + pos=pos, + labels=nx.get_node_attributes(G, "atom_symbol"), + with_labels=True, + node_color=list(nx.get_node_attributes(G, "color").values()), node_size=800, - #edges=G.edges(), - edge_color=[edge_color_map[G[u][v]["bond_type"].name] for u,v in G.edges], - width=[edge_width_map[G[u][v]["bond_type"].name] for u,v in G.edges], - ) + # edges=G.edges(), + edge_color=[edge_color_map[G[u][v]["bond_type"].name] for u, v in G.edges], + width=[edge_width_map[G[u][v]["bond_type"].name] for u, v in G.edges], + ) if edge_labels: nx.draw_networkx_edge_labels( - G, pos, - edge_labels=dict([((u, v), f'({u}, {v})') for u, v in G.edges]), - font_color='red', - ax=ax - ) + G, + pos, + edge_labels=dict([((u, v), f"({u}, {v})") for u, v in G.edges]), + font_color="red", + ax=ax, + ) plt.show() + def get_adjacency_matrix(G): - #return torch.tensor(nx.convert_matrix.to_numpy_matrix(G), dtype=torch.float32) + # return torch.tensor(nx.convert_matrix.to_numpy_matrix(G), dtype=torch.float32) return torch.tensor(nx.convert_matrix.to_numpy_array(G), dtype=torch.float32) + def get_degree_matrix(A): - #return tf.transpose([tf.clip_by_value(tf.reduce_sum(A, axis=-1), 0.0001, 1000.0)]) + # return tf.transpose([tf.clip_by_value(tf.reduce_sum(A, axis=-1), 0.0001, 1000.0)]) return torch.clamp(torch.sum(A, dim=1, keepdim=True), 0.0001, 1000.0) - + def get_identity_matrix(A): return torch.eye(A.shape[0]) + def get_edges(A): edge_idx = [] - deg=get_degree_matrix(A) + deg = get_degree_matrix(A) row = 0 for j in range(deg.shape[0]): - row_degree = int(deg[j,0].numpy()) + row_degree = int(deg[j, 0].numpy()) for i in range(row_degree): edges_to = np.where(A[j] > 0.001)[0] edge_idx.append((j, edges_to[i])) row += row_degree return edge_idx + def compute_edge_related_helper_matrices(A, deg): AL = torch.zeros(torch.sum(deg).int(), A.shape[0]) AR = torch.zeros(AL.shape) @@ -96,7 +104,7 @@ def compute_edge_related_helper_matrices(A, deg): row = 0 for j in range(deg.shape[0]): - row_degree = int(deg[j,0].numpy()) + row_degree = int(deg[j, 0].numpy()) for i in range(row_degree): edges_to = np.where(A[j] > 0.001)[0] AL[row + i, j] = 1.0 @@ -112,8 +120,8 @@ def get_helper_matrices_from_edges(edges, A): AR = torch.zeros(AL.shape) edge_idx = [] - for i, (u,v) in enumerate(edges): + for i, (u, v) in enumerate(edges): AL[i, u] = 1.0 AR[i, v] = 1.0 - - return AL, AR \ No newline at end of file + + return AL, AR diff --git a/fiora/MS/SimulationFramework.py b/fiora/MS/SimulationFramework.py index c7bfdf2..41c5116 100644 --- a/fiora/MS/SimulationFramework.py +++ b/fiora/MS/SimulationFramework.py @@ -1,211 +1,348 @@ import torch import torch_geometric as geom import pandas as pd +import numpy as np from typing import Literal, Dict -import matplotlib.pyplot as plt - from fiora.MOL.Metabolite import Metabolite -from fiora.MS.spectral_scores import * -import fiora.visualization.spectrum_visualizer as sv +from fiora.MS.spectral_scores import ( + spectral_cosine, + spectral_reflection_cosine, + reweighted_dot, +) from fiora.MOL.constants import DEFAULT_MODE_MAP + class SimulationFramework: - - def __init__(self, base_model: torch.nn.Module|None=None, dev: str="cpu"): + def __init__(self, base_model: torch.nn.Module | None = None, dev: str = "cpu"): self.base_model = base_model self.dev = dev self.mode_map = None def __repr__(self): - return f"Simulation framework for MS/MS spectrum generation" - + return "Simulation framework for MS/MS spectrum generation" + def __str__(self): - return f"Simulation framework for MS/MS spectrum 
generation" + return "Simulation framework for MS/MS spectrum generation" def set_mode_mapper(self, mode_map): self.mode_map = mode_map - def predict_metabolite_property(self, metabolite, model: torch.nn.Module|None=None, as_batch: bool=False): + def predict_metabolite_property( + self, metabolite, model: torch.nn.Module | None = None, as_batch: bool = False + ): if not model: - model = self.base_model + model = self.base_model data = metabolite.as_geometric_data(with_labels=False).to(self.dev) if as_batch: data = geom.data.Batch.from_data_list([data]) - - logits = model(data, with_RT=hasattr(model, "rt_module"), with_CCS=hasattr(model, "ccs_module")) + + logits = model( + data, + with_RT=hasattr(model, "rt_module"), + with_CCS=hasattr(model, "ccs_module"), + ) return logits - - def pred_all(self, df: pd.DataFrame, model: torch.nn.Module|None=None, attr_name: str="", as_batch: bool=True): + + def pred_all( + self, + df: pd.DataFrame, + model: torch.nn.Module | None = None, + attr_name: str = "", + as_batch: bool = True, + ): with torch.no_grad(): model.eval() - for i,d in df.iterrows(): + for i, d in df.iterrows(): metabolite = d["Metabolite"] - prediction = self.predict_metabolite_property(metabolite, model=model, as_batch=as_batch) + prediction = self.predict_metabolite_property( + metabolite, model=model, as_batch=as_batch + ) if hasattr(model, "rt_module"): - setattr(metabolite, attr_name + "_pred", prediction["fragment_probs"]) + setattr( + metabolite, attr_name + "_pred", prediction["fragment_probs"] + ) setattr(metabolite, "RT_pred", prediction["rt"].squeeze()) else: - setattr(metabolite, attr_name + "_pred", prediction["fragment_probs"]) + setattr( + metabolite, attr_name + "_pred", prediction["fragment_probs"] + ) return - - def simulate_spectrum(self, metabolite: Metabolite, pred_label: str, precursor_mode: Literal["[M+H]+", "[M-H]-"]="[M+H]+", min_intensity: float=0.001, merge_fragment_duplicates: bool=True, transform_prob: str="None"): + def simulate_spectrum( + self, + metabolite: Metabolite, + pred_label: str, + precursor_mode: Literal["[M+H]+", "[M-H]-"] = "[M+H]+", + min_intensity: float = 0.001, + merge_fragment_duplicates: bool = True, + transform_prob: str = "None", + ): if not self.mode_map: mode_map = DEFAULT_MODE_MAP else: mode_map = self.mode_map - + edge_map = metabolite.fragmentation_tree.edge_map - + sim_probs = getattr(metabolite, pred_label) - sim_peaks = {'mz': [], 'intensity': [], 'annotation': []} - + sim_peaks = {"mz": [], "intensity": [], "annotation": []} + precursor_prob = sim_probs[-1].tolist() precursor = edge_map[None] - - sim_peaks["mz"].append(precursor.mz[precursor_mode]) # TODO allow multiple ion modes of precursor + + sim_peaks["mz"].append( + precursor.mz[precursor_mode] + ) # TODO allow multiple ion modes of precursor sim_peaks["intensity"].append(precursor_prob) sim_peaks["annotation"].append(precursor.smiles + "//" + precursor_mode) - edge_probs = sim_probs[:-2].unflatten(-1, sizes=(-1, len(mode_map)*2)) + edge_probs = sim_probs[:-2].unflatten(-1, sizes=(-1, len(mode_map) * 2)) for i, edge in enumerate(metabolite.edges_as_tuples): - if edge[0] > edge[1]: continue # skip backward directions + if edge[0] > edge[1]: + continue # skip backward directions frags = edge_map.get(edge) - if not frags: continue - - lf = frags.get('left') + if not frags: + continue + + lf = frags.get("left") if lf: for mode, idx in mode_map.items(): - intensity = edge_probs[i,idx].tolist() + intensity = edge_probs[i, idx].tolist() if intensity > min_intensity: - mz 
=lf.mz[mode]
-                        mode_str = mode if precursor_mode=="[M+H]+" else mode.replace("]+", "]-")
+                        mz = lf.mz[mode]
+                        mode_str = (
+                            mode
+                            if precursor_mode == "[M+H]+"
+                            else mode.replace("]+", "]-")
+                        )
                         annotation = lf.smiles + "//" + mode_str
                         merged = False
-                        if merge_fragment_duplicates and (mz in sim_peaks["mz"]): # if exact mz value exists already
+                        if merge_fragment_duplicates and (
+                            mz in sim_peaks["mz"]
+                        ):  # if exact mz value exists already
                             for j, mzx in enumerate(sim_peaks["mz"]):
-                                if mz == mzx and annotation == sim_peaks["annotation"][j]: # check mz and annotation
-                                    sim_peaks["intensity"][j] += intensity # and intensity if exact same fragments
+                                if (
+                                    mz == mzx
+                                    and annotation == sim_peaks["annotation"][j]
+                                ):  # check mz and annotation
+                                    sim_peaks["intensity"][j] += (
+                                        intensity  # and intensity if exact same fragments
+                                    )
                                     merged = True
                                     break
-                            if merged: continue
+                            if merged:
+                                continue
                             sim_peaks["mz"].append(mz)
                             sim_peaks["intensity"].append(intensity)
                             sim_peaks["annotation"].append(annotation)

-            rf = frags.get('right')
+            rf = frags.get("right")
             if rf:
                 for mode, idx in mode_map.items():
-                    idx = (idx + len(mode_map)) % (2*len(mode_map))
-                    intensity = edge_probs[i,idx].tolist()
+                    idx = (idx + len(mode_map)) % (2 * len(mode_map))
+                    intensity = edge_probs[i, idx].tolist()
                     if intensity > min_intensity:
                         mz = rf.mz[mode]
-                        mode_str = mode if precursor_mode=="[M+H]+" else mode.replace("]+", "]-")
+                        mode_str = (
+                            mode
+                            if precursor_mode == "[M+H]+"
+                            else mode.replace("]+", "]-")
+                        )
                         annotation = rf.smiles + "//" + mode_str
                         merged = False
                         if merge_fragment_duplicates and (mz in sim_peaks["mz"]):
                             for j, mzx in enumerate(sim_peaks["mz"]):
-                                if mz == mzx and annotation == sim_peaks["annotation"][j]:
-                                    sim_peaks["intensity"][j] += intensity # and intensity if exact same fragments
+                                if (
+                                    mz == mzx
+                                    and annotation == sim_peaks["annotation"][j]
+                                ):
+                                    sim_peaks["intensity"][j] += (
+                                        intensity  # and intensity if exact same fragments
+                                    )
                                     merged = True
-                                    break 
+                                    break
-                            if merged: continue
+                            if merged:
+                                continue
                             sim_peaks["mz"].append(mz)
                             sim_peaks["intensity"].append(intensity)
-                            sim_peaks["annotation"].append(annotation) 
-        
+                            sim_peaks["annotation"].append(annotation)
+
         if transform_prob == "square":
-            max_prob = max(sim_peaks["intensity"])**2
+            max_prob = max(sim_peaks["intensity"]) ** 2
             for i in range(len(sim_peaks["intensity"])):
-                sim_peaks["intensity"][i] == sim_peaks["intensity"][i]**2 / max_prob
+                sim_peaks["intensity"][i] = sim_peaks["intensity"][i] ** 2 / max_prob
-        
-        
         combined = sorted(
-            zip(sim_peaks["mz"], sim_peaks["intensity"], sim_peaks["annotation"]),
-            key=lambda t: t[0],
-            reverse=True,
+            zip(sim_peaks["mz"], sim_peaks["intensity"], sim_peaks["annotation"]),
+            key=lambda t: t[0],
+            reverse=True,
         )
         mz, inten, annot = zip(*combined)
         sim_peaks["mz"] = list(mz)
         sim_peaks["intensity"] = list(inten)
         sim_peaks["annotation"] = list(annot)
-        
+
         return sim_peaks
-    
-    
-    
-    def simulate_and_score(self, metabolite: Metabolite, model: torch.nn.Module|None=None, base_attr_name: str="compiled_probsALL", query_peaks: Dict|None=None, as_batch: bool=True, min_intensity: float=0.001):
-        prediction = self.predict_metabolite_property(metabolite, model=model, as_batch=as_batch)
+
+    def simulate_and_score(
+        self,
+        metabolite: Metabolite,
+        model: torch.nn.Module | None = None,
+        base_attr_name: str = "compiled_probsALL",
+        query_peaks: Dict | None = None,
+        as_batch: bool = True,
+        min_intensity: float = 0.001,
+    ):
+        prediction = self.predict_metabolite_property(
+            metabolite, model=model, 
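            # Illustrative shape of the dict built by simulate_spectrum() above
            # (values made up; annotations are fragment SMILES + "//" + ion mode):
            #
            #   {"mz": [47.05, 45.03], "intensity": [0.42, 0.13],
            #    "annotation": ["CCO//[M+H]+", "CC=O//[M+H]+"]}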
as_batch=as_batch + ) stats = {} if "rt" in prediction.keys(): stats["RT_pred"] = prediction["rt"].squeeze().tolist() if "ccs" in prediction.keys(): stats["CCS_pred"] = prediction["ccs"].squeeze().tolist() - + setattr(metabolite, base_attr_name + "_pred", prediction["fragment_probs"]) - transform_prob = "square" if ("training_label" in model.model_params and model.model_params["training_label"] == "compiled_probsSQRT") else "None" - stats["sim_peaks"] = self.simulate_spectrum(metabolite, base_attr_name + "_pred", precursor_mode=metabolite.metadata["precursor_mode"], transform_prob=transform_prob, min_intensity=min_intensity) - - + transform_prob = ( + "square" + if ( + "training_label" in model.model_params + and model.model_params["training_label"] == "compiled_probsSQRT" + ) + else "None" + ) + stats["sim_peaks"] = self.simulate_spectrum( + metabolite, + base_attr_name + "_pred", + precursor_mode=metabolite.metadata["precursor_mode"], + transform_prob=transform_prob, + min_intensity=min_intensity, + ) + # Score performance if groundtruth is available if hasattr(metabolite, base_attr_name): groundtruth = getattr(metabolite, base_attr_name).to(self.dev) - - stats["cosine_similarity"] = torch.nn.functional.cosine_similarity(prediction["fragment_probs"], groundtruth, dim=0).tolist() # TODO - stats["kl_div"] = torch.nn.functional.kl_div(torch.log(prediction["fragment_probs"]), groundtruth, reduction='sum').tolist() - + + stats["cosine_similarity"] = torch.nn.functional.cosine_similarity( + prediction["fragment_probs"], groundtruth, dim=0 + ).tolist() # TODO + stats["kl_div"] = torch.nn.functional.kl_div( + torch.log(prediction["fragment_probs"]), groundtruth, reduction="sum" + ).tolist() + if "RT_pred" in stats.keys() and "retention_time" in metabolite.metadata.keys(): - stats["RT_dif"] = abs(stats["RT_pred"] - metabolite.metadata["retention_time"]) - + stats["RT_dif"] = abs( + stats["RT_pred"] - metabolite.metadata["retention_time"] + ) + if query_peaks: - stats["spectral_cosine"], stats["spectral_bias"] = spectral_cosine(query_peaks, stats["sim_peaks"], with_bias=True) - stats["spectral_sqrt_cosine"], stats["spectral_sqrt_bias"] = spectral_cosine(query_peaks, stats["sim_peaks"], transform=np.sqrt, with_bias=True) - stats["spectral_sqrt_cosine_wo_prec"], stats["spectral_sqrt_bias_wo_prec"] = spectral_cosine(query_peaks, stats["sim_peaks"], transform=np.sqrt, remove_mz=metabolite.get_theoretical_precursor_mz(ion_type=metabolite.metadata["precursor_mode"]), with_bias=True) - stats["spectral_sqrt_cosine_avg"], stats["spectral_sqrt_bias_avg"] = (stats["spectral_sqrt_cosine"] + stats["spectral_sqrt_cosine_wo_prec"]) / 2.0, (stats["spectral_sqrt_bias"] + stats["spectral_sqrt_bias_wo_prec"]) / 2.0 - stats["spectral_refl_cosine"], stats["spectral_refl_bias"] = spectral_reflection_cosine(query_peaks, stats["sim_peaks"], transform=np.sqrt, with_bias=True) - stats["steins_cosine"], stats["steins_bias"] = reweighted_dot(query_peaks, stats["sim_peaks"], int_pow=0.5, mz_pow=0.5, with_bias=True) + stats["spectral_cosine"], stats["spectral_bias"] = spectral_cosine( + query_peaks, stats["sim_peaks"], with_bias=True + ) + stats["spectral_sqrt_cosine"], stats["spectral_sqrt_bias"] = ( + spectral_cosine( + query_peaks, stats["sim_peaks"], transform=np.sqrt, with_bias=True + ) + ) + ( + stats["spectral_sqrt_cosine_wo_prec"], + stats["spectral_sqrt_bias_wo_prec"], + ) = spectral_cosine( + query_peaks, + stats["sim_peaks"], + transform=np.sqrt, + remove_mz=metabolite.get_theoretical_precursor_mz( + 
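                    # remove_mz zeroes the precursor bin so this score variant
                    # reflects fragment peaks only; e.g. for caffeine [M+H]+ the
                    # bin near m/z 195.09 would be dropped from both spectra
                    # (illustrative example, not taken from the data)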
ion_type=metabolite.metadata["precursor_mode"] + ), + with_bias=True, + ) + stats["spectral_sqrt_cosine_avg"], stats["spectral_sqrt_bias_avg"] = ( + (stats["spectral_sqrt_cosine"] + stats["spectral_sqrt_cosine_wo_prec"]) + / 2.0, + (stats["spectral_sqrt_bias"] + stats["spectral_sqrt_bias_wo_prec"]) + / 2.0, + ) + stats["spectral_refl_cosine"], stats["spectral_refl_bias"] = ( + spectral_reflection_cosine( + query_peaks, stats["sim_peaks"], transform=np.sqrt, with_bias=True + ) + ) + stats["steins_cosine"], stats["steins_bias"] = reweighted_dot( + query_peaks, stats["sim_peaks"], int_pow=0.5, mz_pow=0.5, with_bias=True + ) return stats - - def simulate_all(self, df: pd.DataFrame, model: torch.nn.Module|None=None, base_attr_name: str="compiled_probsALL", suffix: str="", groundtruth=True, min_intensity: float=0.001): - + + def simulate_all( + self, + df: pd.DataFrame, + model: torch.nn.Module | None = None, + base_attr_name: str = "compiled_probsALL", + suffix: str = "", + groundtruth=True, + min_intensity: float = 0.001, + ): + with torch.no_grad(): model.eval() - for i,data in df.iterrows(): + for i, data in df.iterrows(): metabolite = data["Metabolite"] - stats = self.simulate_and_score(metabolite, model, base_attr_name, query_peaks=data["peaks"] if groundtruth else None, min_intensity=min_intensity) - df = pd.concat([df, pd.DataFrame(columns=[x + suffix for x in stats.keys()])]) # Add new empty columns for all statistics - + stats = self.simulate_and_score( + metabolite, + model, + base_attr_name, + query_peaks=data["peaks"] if groundtruth else None, + min_intensity=min_intensity, + ) + df = pd.concat( + [df, pd.DataFrame(columns=[x + suffix for x in stats.keys()])] + ) # Add new empty columns for all statistics + for key, value in stats.items(): if key + suffix in df.columns: df.at[i, key + suffix] = value setattr(metabolite, key + suffix, value) else: - raise Warning("User Warning: Attempting to add data to non-existing column simulate_all().\n\tSolve by adding column with pd.concat()") - + raise Warning( + "User Warning: Attempting to add data to non-existing column simulate_all().\n\tSolve by adding column with pd.concat()" + ) + return df - - - def plot_feature_prediction_vectors(self, metabolite, label, with_mol=True, transform=None): - + + def plot_feature_prediction_vectors( + self, metabolite, label, with_mol=True, transform=None + ): + + import matplotlib.pyplot as plt + import fiora.visualization.spectrum_visualizer as sv + if with_mol: - fig, axs = plt.subplots(1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False) - img = metabolite.draw(ax=axs[0]) - axs1 = axs[1] + _, axs = plt.subplots( + 1, + 2, + figsize=(12.8, 4.2), + gridspec_kw={"width_ratios": [1, 3]}, + sharey=False, + ) + _ = metabolite.draw(ax=axs[0]) else: - fig, axs1 = plt.subplots(1, 1, figsize=(12.8, 4.2)) - - relevant_edge_index = torch.logical_and(metabolite.compiled_validation_mask, metabolite.compiled_forward_mask) - probs = getattr(metabolite, label).to(dev)[relevant_edge_index] - preds = getattr(metabolite, 'predicted_' + label).to(self.dev)[relevant_edge_index] - + _, _ = plt.subplots(1, 1, figsize=(12.8, 4.2)) + + relevant_edge_index = torch.logical_and( + metabolite.compiled_validation_mask, metabolite.compiled_forward_mask + ) + probs = getattr(metabolite, label).to(self.dev)[relevant_edge_index] + preds = getattr(metabolite, "predicted_" + label).to(self.dev)[ + relevant_edge_index + ] + names = [f"e{i}" for i in range(preds.shape[0] - 1)] + ["prec"] - ax = 
sv.plot_vector_spectrum(probs.tolist(), preds.tolist(), ax=axs[1], names=names)
+        _ = sv.plot_vector_spectrum(
+            probs.tolist(), preds.tolist(), ax=axs[1], names=names
+        )
         plt.show()
-        
-        
diff --git a/fiora/MS/ms_utility.py b/fiora/MS/ms_utility.py
index 0726cfe..f1b4db3 100644
--- a/fiora/MS/ms_utility.py
+++ b/fiora/MS/ms_utility.py
@@ -1,57 +1,65 @@
-#from modules.MOL.FragmentationTree import FragmentationTree
+# from modules.MOL.FragmentationTree import FragmentationTree
 from fiora.MOL.constants import PPM, DEFAULT_PPM, MIN_ABS_TOLERANCE
 from typing import Literal
 import numpy as np
 import copy
 
-def do_mz_values_match(mz, mz_other, tolerance, in_ppm=True, require_minimum_tolerance=True):
-    if in_ppm: tolerance = tolerance * mz
-    if tolerance 0: multiples.append((mz, ff)) elif n == 0: unidentified.append(mz)
-    
+
     return uniques, multiples, unidentified
 
-def normalize_spectrum(spec, type: Literal["max_intensity", "norm"]="norm"):
-    if type=="max_intensity":
+def normalize_spectrum(spec, type: Literal["max_intensity", "norm"] = "norm"):
+    if type == "max_intensity":
         maximum = max(spec["intensity"])
         spec["intensity"] = [i / maximum for i in spec["intensity"]]
-    elif type=="norm":
-        spec["intensity"] = list(np.array(spec["intensity"]) / np.linalg.norm(spec["intensity"]) )
+    elif type == "norm":
+        spec["intensity"] = list(
+            np.array(spec["intensity"]) / np.linalg.norm(spec["intensity"])
+        )
     else:
         raise ValueError("Unknown type of normalization")
 
 def merge_annotated_spectrum(spec1, spec2):
     spec1 = copy.deepcopy(spec1)
-    spec2_red = {"mz": [], 'intensity': [], 'annotation': []}
+    spec2_red = {"mz": [], "intensity": [], "annotation": []}
     for i, mz2 in enumerate(spec2["mz"]):
         merged_peak = False
-        if (mz2 in spec1["mz"]):
+        if mz2 in spec1["mz"]:
             for j, mz1 in enumerate(spec1["mz"]):
                 if mz1 == mz2 and spec1["annotation"][j] == spec2["annotation"][i]:
                     spec1["intensity"][j] += spec2["intensity"][i]
@@ -61,24 +69,24 @@ def merge_annotated_spectrum(spec1, spec2):
             spec2_red["mz"] += [spec2["mz"][i]]
             spec2_red["intensity"] += [spec2["intensity"][i]]
             spec2_red["annotation"] += [spec2["annotation"][i]]
-    
+
     spec1["mz"] += spec2_red["mz"]
     spec1["intensity"] += spec2_red["intensity"]
     spec1["annotation"] += spec2_red["annotation"]
-    
-    return spec1
-    
+    return spec1
+
 
-def merge_spectrum(spec1, spec2, merge_tolerance: float=0.0):
+def merge_spectrum(spec1, spec2, merge_tolerance: float = 0.0):
     spec1 = copy.deepcopy(spec1)
     if merge_tolerance > 0.01:
-        raise Warning("Merging peaks recommended only for very small tolerances. Peak merging has mainly a visual impact is not needed for computation.")
-    spec2_red = {"mz": [], 'intensity': []}
+        raise Warning(
+            "Merging peaks recommended only for very small tolerances. Peak merging has mainly a visual impact and is not needed for computation." 
+ ) + spec2_red = {"mz": [], "intensity": []} for i, mz2 in enumerate(spec2["mz"]): merged_peak = False - + for j, mz1 in enumerate(spec1["mz"]): if abs(mz1 - mz2) <= merge_tolerance: spec1["intensity"][j] += spec2["intensity"][i] @@ -88,14 +96,14 @@ def merge_spectrum(spec1, spec2, merge_tolerance: float=0.0): if not merged_peak: spec2_red["mz"] += [spec2["mz"][i]] spec2_red["intensity"] += [spec2["intensity"][i]] - + spec1["mz"] += spec2_red["mz"] spec1["intensity"] += spec2_red["intensity"] - + return spec1 -''' +""" def match_first_order_peaks(df, offset=0): @@ -126,4 +134,4 @@ def match_first_order_peaks(df, offset=0): percentage.append(c_percentage) return peaks, unique, percentage -''' \ No newline at end of file +""" diff --git a/fiora/MS/spectral_scores.py b/fiora/MS/spectral_scores.py index e0af147..e1e5a4d 100644 --- a/fiora/MS/spectral_scores.py +++ b/fiora/MS/spectral_scores.py @@ -1,55 +1,72 @@ import numpy as np -from typing import Literal, Dict from fiora.MOL.constants import DEFAULT_DALTON + def cosine(vec, vec_other): return np.dot(vec, vec_other) / (np.linalg.norm(vec) * np.linalg.norm(vec_other)) + def cosine_bias_alt(vec, vec_other, cosine_precomputed=None): if not cosine_precomputed: cosine_precomputed = cosine(vec, vec_other) - + if cosine_precomputed <= 0.0: return 1.0 - bias = np.sqrt(np.square(np.dot(vec, vec_other)) / np.square(np.linalg.norm(vec) * np.linalg.norm(vec_other))) + bias = np.sqrt( + np.square(np.dot(vec, vec_other)) + / np.square(np.linalg.norm(vec) * np.linalg.norm(vec_other)) + ) return bias / cosine_precomputed def cosine_bias(vec, vec_other, cosine_precomputed=None): if not cosine_precomputed: cosine_precomputed = cosine(vec, vec_other) - + if cosine_precomputed <= 0.0: return 1.0 vec = vec / np.linalg.norm(vec) vec_other = vec_other / np.linalg.norm(vec_other) - bias = np.sqrt(np.dot(np.square(vec), np.square(vec_other))) # wrong: bias = np.sqrt(np.square(np.dot(vec, vec_other)) / (np.square(np.linalg.norm(vec)) * np.square(np.linalg.norm(vec_other)))) + bias = np.sqrt( + np.dot(np.square(vec), np.square(vec_other)) + ) # wrong: bias = np.sqrt(np.square(np.dot(vec, vec_other)) / (np.square(np.linalg.norm(vec)) * np.square(np.linalg.norm(vec_other)))) return bias / cosine_precomputed + def create_mz_map(mz, mz_other, tolerance): mz_unique = np.sort(np.unique(mz + mz_other)) bin_map = dict(zip(mz_unique, range(len(mz_unique)))) - + for i, mz in enumerate(mz_unique[:-1]): - if abs(mz_unique[i+1] - mz) < tolerance: #TODO this may lead to continuous mz matches that lie >tolerance apart on the edges - bin_map[mz_unique[i+1]] = bin_map[mz] #TODO remove unmapped bins + if ( + abs(mz_unique[i + 1] - mz) < tolerance + ): # TODO this may lead to continuous mz matches that lie >tolerance apart on the edges + bin_map[mz_unique[i + 1]] = bin_map[mz] # TODO remove unmapped bins return bin_map -def spectral_cosine(spec, spec_ref, tolerance=DEFAULT_DALTON, transform=None, with_bias=False, remove_mz: float|None=None): + +def spectral_cosine( + spec, + spec_ref, + tolerance=DEFAULT_DALTON, + transform=None, + with_bias=False, + remove_mz: float | None = None, +): mz_map = create_mz_map(spec["mz"], spec_ref["mz"], tolerance=tolerance) - vec, vec_ref = np.zeros(len(mz_map)), np.zeros(len(mz_map)) - + vec, vec_ref = np.zeros(len(mz_map)), np.zeros(len(mz_map)) + bins = list(map(mz_map.get, spec["mz"])) bins_ref = list(map(mz_map.get, spec_ref["mz"])) - - np.add.at(vec, bins, spec["intensity"]) #vec.put(bins, spec["intensity"]) - np.add.at(vec_ref, bins_ref, 
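    # Worked example of the binning above: with tolerance 0.05, the m/z lists
    # [100.00, 100.03] and [100.03, 200.00] give create_mz_map bins
    # {100.00: 0, 100.03: 0, 200.00: 2}, so near-identical peaks from both
    # spectra are summed into the same vector component before the cosine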
spec_ref["intensity"]) - # zero out specific mz value, e.g. precursor m/z + np.add.at(vec, bins, spec["intensity"]) # vec.put(bins, spec["intensity"]) + np.add.at(vec_ref, bins_ref, spec_ref["intensity"]) + + # zero out specific mz value, e.g. precursor m/z if remove_mz: bin = None for mz in mz_map.keys(): @@ -59,11 +76,10 @@ def spectral_cosine(spec, spec_ref, tolerance=DEFAULT_DALTON, transform=None, wi if bin: vec[bin] = 0 vec_ref[bin] = 0 - if transform: - vec=transform(vec) - vec_ref=transform(vec_ref) + vec = transform(vec) + vec_ref = transform(vec_ref) cos = cosine(vec, vec_ref) if with_bias: @@ -71,24 +87,27 @@ def spectral_cosine(spec, spec_ref, tolerance=DEFAULT_DALTON, transform=None, wi return cos, bias return cos -def spectral_reflection_cosine(spec, spec_ref, tolerance=DEFAULT_DALTON, transform=None, with_bias=False): + +def spectral_reflection_cosine( + spec, spec_ref, tolerance=DEFAULT_DALTON, transform=None, with_bias=False +): mz_map = create_mz_map(spec["mz"], spec_ref["mz"], tolerance=tolerance) - vec, vec_ref = np.zeros(len(mz_map)), np.zeros(len(mz_map)) - + vec, vec_ref = np.zeros(len(mz_map)), np.zeros(len(mz_map)) + bins = list(map(mz_map.get, spec["mz"])) bins_ref = list(map(mz_map.get, spec_ref["mz"])) - - np.add.at(vec, bins, spec["intensity"]) #vec.put(bins, spec["intensity"]) - np.add.at(vec_ref, bins_ref, spec_ref["intensity"]) - - #Reflection score: Remove values that are not matched with the reference values - unmatched_bins = [b for b in bins if b not in bins_ref] - vec.put(unmatched_bins, 0.) - + + np.add.at(vec, bins, spec["intensity"]) # vec.put(bins, spec["intensity"]) + np.add.at(vec_ref, bins_ref, spec_ref["intensity"]) + + # Reflection score: Remove values that are not matched with the reference values + unmatched_bins = [b for b in bins if b not in bins_ref] + vec.put(unmatched_bins, 0.0) + if transform: - vec=transform(vec) - vec_ref=transform(vec_ref) - + vec = transform(vec) + vec_ref = transform(vec_ref) + cos = cosine(vec, vec_ref) if with_bias: bias = cosine_bias(vec, vec_ref, cosine_precomputed=cos) @@ -96,35 +115,30 @@ def spectral_reflection_cosine(spec, spec_ref, tolerance=DEFAULT_DALTON, transfo return cos - - -def create_mz_map(mz, mz_other, tolerance): - mz_unique = np.sort(np.unique(mz + mz_other)) - bin_map = dict(zip(mz_unique, range(len(mz_unique)))) - - for i, mz in enumerate(mz_unique[:-1]): - if (mz_unique[i+1] - mz) < tolerance: #TODO this may lead to continuous mz matches that lie >tolerance apart on the edges - bin_map[mz_unique[i+1]] = bin_map[mz] #TODO remove unmapped bins - - return bin_map - #### TODO TEST THIS -def reweighted_dot(spec, spec_ref, int_pow=0.5, mz_pow=0.5, tolerance=DEFAULT_DALTON, with_bias=False): +def reweighted_dot( + spec, spec_ref, int_pow=0.5, mz_pow=0.5, tolerance=DEFAULT_DALTON, with_bias=False +): mz_map = create_mz_map(spec["mz"], spec_ref["mz"], tolerance=tolerance) - vec, vec_ref = np.zeros(len(mz_map)), np.zeros(len(mz_map)) - + vec, vec_ref = np.zeros(len(mz_map)), np.zeros(len(mz_map)) + bins = list(map(mz_map.get, spec["mz"])) bins_ref = list(map(mz_map.get, spec_ref["mz"])) - - spec["mz_int"] = [np.power(spec["intensity"][i], int_pow)*np.power(mz, mz_pow) for i, mz in enumerate(spec["mz"])] - spec_ref["mz_int"] = [np.power(spec_ref["intensity"][i], int_pow)*np.power(mz, mz_pow) for i, mz in enumerate(spec_ref["mz"])] - + + spec["mz_int"] = [ + np.power(spec["intensity"][i], int_pow) * np.power(mz, mz_pow) + for i, mz in enumerate(spec["mz"]) + ] + spec_ref["mz_int"] = [ + 
np.power(spec_ref["intensity"][i], int_pow) * np.power(mz, mz_pow) + for i, mz in enumerate(spec_ref["mz"]) + ] + np.add.at(vec, bins, spec["mz_int"]) np.add.at(vec_ref, bins_ref, spec_ref["mz_int"]) - cos = cosine(vec, vec_ref) if with_bias: bias = cosine_bias(vec, vec_ref, cosine_precomputed=cos) return cos, bias - return cos \ No newline at end of file + return cos diff --git a/fiora/cli/__init__.py b/fiora/cli/__init__.py new file mode 100644 index 0000000..35fc369 --- /dev/null +++ b/fiora/cli/__init__.py @@ -0,0 +1 @@ +# CLI package for fiora entry points. diff --git a/fiora/cli/predict.py b/fiora/cli/predict.py new file mode 100644 index 0000000..c230f80 --- /dev/null +++ b/fiora/cli/predict.py @@ -0,0 +1,362 @@ +#! /usr/bin/env python +import argparse +import importlib.resources as resources +import warnings + +warnings.filterwarnings("ignore", category=SyntaxWarning) + +import pandas as pd +from rdkit import RDLogger + +import fiora.IO.mgfWriter as mgfWriter +import fiora.IO.mspWriter as mspWriter +from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder +from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder +from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder +from fiora.GNN.FioraModel import FioraModel +from fiora.MOL.Metabolite import Metabolite +from fiora.MS.SimulationFramework import SimulationFramework + +RDLogger.DisableLog("rdApp.*") +warnings.filterwarnings( + "ignore", category=UserWarning, message="TypedStorage is deprecated" +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + prog="fiora-predict", + description=( + "Fiora is an in silico fragmentation framework, which predicts peaks and " + "simulates tandem mass spectra including features such as retention time " + "and collision cross sections. Use this script for spectrum predictions " + "with a (pre-)trained model." + ), + epilog="Disclaimer:\nNo prediction software is perfect. Use with caution.", + ) + parser.add_argument( + "-i", + "--input", + help="Input file containing molecular structures (SMILES/InChi) and metadata (.csv file)", + type=str, + required=True, + ) + parser.add_argument( + "-o", + "--output", + help="Output file path (.mgf/.msp file)", + type=str, + required=True, + ) + parser.add_argument( + "--model", + help="Path to prediction model (.pt file)", + type=str, + default="default", + ) + parser.add_argument( + "--dev", + help="Device to the model. For example cuda:0 for GPU number 0.", + type=str, + default="cpu", + ) + parser.add_argument( + "--min_prob", + help="Minimum peak probability to be recorded in the spectrum", + type=float, + default=0.001, + ) + + parser.add_argument( + "--rt", + action=argparse.BooleanOptionalAction, + help="Predict retention time", + default=False, + ) + parser.add_argument( + "--ccs", + action=argparse.BooleanOptionalAction, + help="Predict collison cross section", + default=False, + ) + parser.add_argument( + "--annotation", + action=argparse.BooleanOptionalAction, + help="Annotate predicted peaks with SMILES strings", + default=False, + ) + parser.add_argument( + "--debug", + action=argparse.BooleanOptionalAction, + help="Receive debug information", + default=False, + ) + return parser.parse_args() + + +def update_args_with_model_params( + args: argparse.Namespace, model_params: dict +) -> argparse.Namespace: + if "rt_supported" in model_params.keys(): + if not model_params["rt_supported"] and args.rt: + print( + "Warning: RT prediction is not support by the model. 
Overwriting user argument to --no-rt.\n"
+            )
+            args.rt = False
+    if "ccs_supported" in model_params.keys():
+        if not model_params["ccs_supported"] and args.ccs:
+            print(
+                "Warning: CCS prediction is not supported by the model. Overwriting user argument to --no-ccs.\n"
+            )
+            args.ccs = False
+    return args
+
+
+def print_model_messages(model_params: dict) -> None:
+    if "version" in model_params.keys():
+        print("\n-----Model-----")
+        print(model_params["version"])
+        print("---------------")
+    if "disclaimer" in model_params.keys():
+        dis_msg = model_params["disclaimer"]
+        print(f"\nDisclaimer: {dis_msg}")
+
+
+metadata_key_map = {
+    "name": "Name",
+    "collision_energy": "CE",
+    "instrument": "Instrument_type",
+    "precursor_mode": "Precursor_type",
+}
+
+
+def safe_metabolite_creation(smiles):
+    try:
+        return Metabolite(smiles)
+    except (AssertionError, ValueError):
+        return None
+
+
+def build_metabolites(df: pd.DataFrame, model_params: dict):
+    # Set feature encoder up
+    ce_upper_limit = 100.0
+    weight_upper_limit = 1000.0
+
+    model_setup_feature_sets = None
+    if "setup_features_categorical_set" in model_params.keys():
+        model_setup_feature_sets = model_params["setup_features_categorical_set"]
+
+    node_encoder = AtomFeatureEncoder(
+        feature_list=["symbol", "num_hydrogen", "ring_type"]
+    )
+    bond_encoder = BondFeatureEncoder(feature_list=["bond_type", "ring_type"])
+    if model_params["version_number"] == "0.1.0":
+        covariate_features = [
+            "collision_energy",
+            "molecular_weight",
+            "precursor_mode",
+            "instrument",
+        ]
+    else:
+        covariate_features = [
+            "collision_energy",
+            "molecular_weight",
+            "precursor_mode",
+            "instrument",
+            "element_composition",
+        ]
+    setup_encoder = CovariateFeatureEncoder(
+        feature_list=covariate_features, sets_overwrite=model_setup_feature_sets
+    )
+    rt_encoder = CovariateFeatureEncoder(
+        feature_list=["molecular_weight", "precursor_mode", "instrument"],
+        sets_overwrite=model_setup_feature_sets,
+    )
+
+    setup_encoder.normalize_features["collision_energy"]["max"] = ce_upper_limit
+    setup_encoder.normalize_features["molecular_weight"]["max"] = weight_upper_limit
+    rt_encoder.normalize_features["molecular_weight"]["max"] = weight_upper_limit
+
+    # Convert SMILES to Metabolites and create structure graphs and fragmentation trees
+    df["Metabolite"] = df["SMILES"].apply(safe_metabolite_creation)
+    invalid_df = df[df["Metabolite"].isna()][["Name", "SMILES"]]
+    df.dropna(subset=["Metabolite"], inplace=True)
+
+    df["Metabolite"].apply(lambda x: x.create_molecular_structure_graph())
+    df["Metabolite"].apply(
+        lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)
+    )
+
+    # Map covariate features to dedicated format and encode
+    df["summary"] = df.apply(
+        lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1
+    )
+    df.apply(
+        lambda x: x["Metabolite"].add_metadata(x["summary"], setup_encoder, rt_encoder),
+        axis=1,
+    )
+
+    # Fragment compounds
+    df["Metabolite"].apply(lambda x: x.fragment_MOL(depth=1))
+    return df, invalid_df
+
+
+def prepare_output(args, df, model):
+    df["peaks"] = df["sim_peaks"]
+    df["Formula"] = df["Metabolite"].apply(lambda x: x.Formula)
+    df["Precursor_MZ"] = df["Metabolite"].apply(
+        lambda x: x.get_theoretical_precursor_mz(ion_type=x.metadata["precursor_mode"])
+    )
+
+    # Rename certain columns
+    if "RT_pred" in df.columns:
+        df["RETENTIONTIME"] = df["RT_pred"]
+    df["PRECURSOR_MZ"] = df["Precursor_MZ"]
+    df["FORMULA"] = df["Formula"]
+    if "CCS_pred" in df.columns:
+        df["CCS"] = df["CCS_pred"]
+    version = (
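    # Illustrative input row for build_metabolites() above, using caffeine and
    # the column names expected by metadata_key_map (values are examples only):
    #
    #   df = pd.DataFrame([{"Name": "caffeine",
    #                       "SMILES": "CN1C=NC2=C1C(=O)N(C)C(=O)N2C",
    #                       "CE": 20.0, "Instrument_type": "Orbitrap",
    #                       "Precursor_type": "[M+H]+"}])
    #   df, invalid_df = build_metabolites(df, model.model_params)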
model.model_params["version"] + if "version" in model.model_params + else "(pre-release version v0.0.0)" + ) + df["Comment"] = f'"In silico generated spectrum by {version}"' + df["COMMENT"] = df["Comment"] + + # Write output file + if args.output.endswith(".msp"): + df["Collision_energy"] = df["CE"] + headers = [ + "Name", + "SMILES", + "Formula", + "Precursor_MZ", + "Precursor_type", + "Instrument_type", + "Collision_energy", + ] + if args.rt: + headers.append("RETENTIONTIME") + if args.ccs: + headers.append("CCS") + headers.append("Comment") + mspWriter.write_msp( + df, + path=args.output, + write_header=True, + headers=headers, + annotation=args.annotation, + ) + elif args.output.endswith(".mgf"): + headers = [ + "TITLE", + "SMILES", + "FORMULA", + "PRECURSOR_MZ", + "PRECURSORTYPE", + "COLLISIONENERGY", + "INSTRUMENTTYPE", + ] + if args.rt: + headers.append("RETENTIONTIME") + if args.ccs: + headers.append("CCS") + headers.append("COMMENT") + mgfWriter.write_mgf( + df, + path=args.output, + write_header=True, + headers=headers, + header_map={ + "TITLE": "Name", + "PRECURSORTYPE": "Precursor_type", + "INSTRUMENTTYPE": "Instrument_type", + "COLLISIONENERGY": "CE", + }, + annotation=args.annotation, + ) + else: + print( + f"Warning: Unknown output format {args.output}. Writing results to {args.output}.mgf instead." + ) + args.output = args.output + ".mgf" + headers = [ + "TITLE", + "SMILES", + "FORMULA", + "PRECURSORTYPE", + "COLLISIONENERGY", + "INSTRUMENTTYPE", + ] + if args.rt: + headers.append("RETENTIONTIME") + if args.ccs: + headers.append("CCS") + headers.append("COMMENT") + mgfWriter.write_mgf( + df, + path=args.output, + write_header=True, + headers=headers, + header_map={ + "TITLE": "Name", + "PRECURSORTYPE": "Precursor_type", + "INSTRUMENTTYPE": "Instrument_type", + "COLLISIONENERGY": "CE", + }, + annotation=args.annotation, + ) + + +def main() -> None: + args = parse_args() + if args.debug: + print(f"Running fiora prediction with the following parameters: {args}\n") + + # Load model + if args.model == "default": + with resources.as_file( + resources.files("models").joinpath("fiora_OS_v1.0.0.pt") + ) as model_path: + args.model = str(model_path) + + try: + model = FioraModel.load_from_state_dict(args.model) + except Exception as exc: + raise SystemExit( + f"Error: Failed loading from model from state dict. Caused by: {exc}." + ) + + print_model_messages(model.model_params) + args = update_args_with_model_params(args, model.model_params) + + model.eval() + model = model.to(args.dev) + + # Set up Fiora + fiora = SimulationFramework(None, dev=args.dev) + + # Load the data + df = pd.read_csv(args.input) + + # Construct molecular structure graphs and fragmentation trees + df, invalid_df = build_metabolites(df, model.model_params) + if invalid_df.shape[0] > 0: + if args.debug: + print("Warning: The following input SMILES could not be read or formatted:") + print(invalid_df) + else: + print( + "Warning: Some SMILES could not be read or formatted. Run with --debug flag for more information." + ) + + # Simulate compound fragmentation + df = fiora.simulate_all(df, model, groundtruth=False, min_intensity=args.min_prob) + + # Prepare Output + prepare_output(args, df, model) + print(f"Finished prediction. Exported MS/MS spectra to {args.output}.") + + +if __name__ == "__main__": + main() diff --git a/fiora/cli/train.py b/fiora/cli/train.py new file mode 100644 index 0000000..7950e99 --- /dev/null +++ b/fiora/cli/train.py @@ -0,0 +1,539 @@ +#! 
/usr/bin/env python +import argparse +import ast +import json +import os +import warnings + +import numpy as np +import pandas as pd +import torch +from rdkit import RDLogger + +from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder +from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder +from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder +from fiora.GNN.FioraModel import FioraModel +from fiora.GNN.Losses import ( + GraphwiseKLLoss, + GraphwiseKLLossMetric, + WeightedMAELoss, + WeightedMAEMetric, + WeightedMSELoss, + WeightedMSEMetric, +) +from fiora.GNN.SpectralTrainer import SpectralTrainer +from fiora.IO.LibraryLoader import LibraryLoader +from fiora.MOL.Metabolite import Metabolite +from fiora.MOL.MetaboliteIndex import MetaboliteIndex +from fiora.MOL.constants import DEFAULT_MODES, DEFAULT_PPM + +RDLogger.DisableLog("rdApp.*") +warnings.filterwarnings("ignore", category=SyntaxWarning) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + prog="fiora-train", + description="Train a FIORA model from a preprocessed library CSV.", + ) + parser.add_argument( + "-i", + "--input", + required=True, + help="Path to preprocessed CSV containing spectra, metadata, and SMILES.", + ) + parser.add_argument( + "-o", + "--output", + default="checkpoint_fiora.best.pt", + help="Output path for best checkpoint (.pt).", + ) + parser.add_argument( + "--model-params", + help="Optional path to a JSON file with base model parameters.", + default=None, + ) + parser.add_argument( + "--resume", + help="Optional path to a checkpoint to resume from (.pt).", + default=None, + ) + parser.add_argument( + "--device", + default="auto", + help="Device to run on (e.g. cpu, cuda:0). Default: auto.", + ) + parser.add_argument("--epochs", type=int, default=300) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--learning-rate", type=float, default=2e-4) + parser.add_argument("--weight-decay", type=float, default=1e-5) + parser.add_argument( + "--loss", + choices=["graphwise_kl", "weighted_mse", "weighted_mae", "mse"], + default="graphwise_kl", + ) + parser.add_argument( + "--y-label", + default="compiled_probsALL", + help="Label to use as training target.", + ) + parser.add_argument( + "--with-rt", + action=argparse.BooleanOptionalAction, + default=False, + help="Train RT head if available.", + ) + parser.add_argument( + "--with-ccs", + action=argparse.BooleanOptionalAction, + default=False, + help="Train CCS head if available.", + ) + parser.add_argument("--train-val-split", type=float, default=0.8) + parser.add_argument( + "--split-by-group", + action=argparse.BooleanOptionalAction, + default=True, + help="Split train/val by group_id (prevents leakage).", + ) + parser.add_argument("--group-id-col", default="group_id") + parser.add_argument("--datasplit-col", default="datasplit") + parser.add_argument("--train-label", default="training") + parser.add_argument("--val-label", default="validation") + parser.add_argument("--min-peak-matches", type=int, default=2) + parser.add_argument( + "--ppm", + type=float, + default=None, + help="Default ppm tolerance if column missing.", + ) + parser.add_argument("--ppm-col", default="ppm_peak_tolerance") + parser.add_argument("--summary-col", default="summary") + parser.add_argument("--peaks-col", default="peaks") + parser.add_argument("--smiles-col", default="SMILES") + parser.add_argument("--loss-weight-col", default="loss_weight") + parser.add_argument("--max-rows", type=int, default=None) + 
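    # Example invocation using the flags defined in this parser (paths are
    # placeholders):
    #
    #   fiora-train -i library.csv -o checkpoints/fiora.best.pt \
    #       --epochs 100 --batch-size 64 --device cuda:0 --with-rt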
parser.add_argument("--fragmentation-depth", type=int, default=1) + parser.add_argument( + "--use-frag-index", + action=argparse.BooleanOptionalAction, + default=True, + help="Use MetaboliteIndex to cache fragmentation trees.", + ) + parser.add_argument( + "--graph-mismatch-policy", + choices=["recompute", "ignore"], + default="recompute", + ) + parser.add_argument( + "--precursor-modes", + default=None, + help="Comma-separated precursor modes to encode.", + ) + parser.add_argument( + "--instruments", + default=None, + help="Comma-separated instrument types to encode.", + ) + parser.add_argument("--ce-upper-limit", type=float, default=100.0) + parser.add_argument("--weight-upper-limit", type=float, default=1000.0) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--num-workers", type=int, default=0) + parser.add_argument("--val-every", type=int, default=1) + parser.add_argument( + "--use-validation-mask", + action=argparse.BooleanOptionalAction, + default=False, + help="Use validation mask during validation.", + ) + parser.add_argument("--validation-mask-name", default="validation_mask") + parser.add_argument( + "--scheduler", + choices=["plateau", "none"], + default="plateau", + ) + parser.add_argument("--scheduler-patience", type=int, default=8) + parser.add_argument("--scheduler-factor", type=float, default=0.5) + parser.add_argument( + "--rt-metric", + action=argparse.BooleanOptionalAction, + default=False, + help="Track RT/CCS metrics instead of fragment metrics.", + ) + parser.add_argument( + "--index-col", + type=int, + default=0, + help="CSV index column (default: 0). Use --no-index-col to disable.", + ) + parser.add_argument( + "--no-index-col", + action="store_true", + help="Disable index_col when reading CSV.", + ) + return parser.parse_args() + + +def _parse_dict(val): + if isinstance(val, dict): + return val + if val is None or (isinstance(val, float) and np.isnan(val)): + return None + text = str(val) + try: + return ast.literal_eval(text.replace("nan", "None")) + except Exception: + return None + + +def _parse_dict_columns(df: pd.DataFrame, columns: list[str]) -> pd.DataFrame: + for col in columns: + if col in df.columns: + df[col] = df[col].apply(_parse_dict) + return df + + +def _safe_metabolite(smiles: str): + try: + return Metabolite(smiles) + except Exception: + return None + + +def _build_summary_from_columns(row, metadata_key_map): + summary = {} + for key, col in metadata_key_map.items(): + if col in row.index: + value = row[col] + if value is not None and not (isinstance(value, float) and np.isnan(value)): + summary[key] = value + return summary + + +def _resolve_device(device: str) -> str: + if device == "auto": + return "cuda:0" if torch.cuda.is_available() else "cpu" + return device + + +def _load_model_params(path: str | None) -> dict: + if path is None: + return {} + with open(path, "r") as fp: + return json.load(fp) + + +def _choose_loss(loss_name: str): + if loss_name == "graphwise_kl": + return GraphwiseKLLoss(reduction="mean"), {"mse": GraphwiseKLLossMetric} + if loss_name == "weighted_mse": + return WeightedMSELoss(), {"mse": WeightedMSEMetric} + if loss_name == "weighted_mae": + return WeightedMAELoss(), {"mae": WeightedMAEMetric} + if loss_name == "mse": + return torch.nn.MSELoss(), None + raise ValueError(f"Unknown loss: {loss_name}") + + +def main() -> None: + args = parse_args() + dev = _resolve_device(args.device) + np.seterr(invalid="ignore") + + index_col = None if args.no_index_col else args.index_col + loader = 
LibraryLoader() + df = ( + loader.load_from_csv(args.input) + if index_col == 0 + else pd.read_csv(args.input, index_col=index_col, low_memory=False) + ) + + if args.max_rows: + df = df.iloc[: args.max_rows].copy() + + df = _parse_dict_columns(df, [args.summary_col, args.peaks_col]) + + # Prepare encoders + overwrite_sets = {} + if args.instruments: + overwrite_sets["instrument"] = [ + x.strip() for x in args.instruments.split(",") if x.strip() + ] + if args.precursor_modes: + overwrite_sets["precursor_mode"] = [ + x.strip() for x in args.precursor_modes.split(",") if x.strip() + ] + if not overwrite_sets: + overwrite_sets = None + + node_encoder = AtomFeatureEncoder( + feature_list=["symbol", "num_hydrogen", "ring_type"] + ) + bond_encoder = BondFeatureEncoder(feature_list=["bond_type", "ring_type"]) + covariate_encoder = CovariateFeatureEncoder( + feature_list=[ + "collision_energy", + "molecular_weight", + "precursor_mode", + "instrument", + "element_composition", + ], + sets_overwrite=overwrite_sets, + ) + rt_encoder = CovariateFeatureEncoder( + feature_list=[ + "molecular_weight", + "precursor_mode", + "instrument", + "element_composition", + ], + sets_overwrite=overwrite_sets, + ) + covariate_encoder.normalize_features["collision_energy"]["max"] = ( + args.ce_upper_limit + ) + covariate_encoder.normalize_features["molecular_weight"]["max"] = ( + args.weight_upper_limit + ) + rt_encoder.normalize_features["molecular_weight"]["max"] = args.weight_upper_limit + + metadata_key_map = { + "name": "Name", + "collision_energy": "CE", + "instrument": "Instrument_type", + "precursor_mode": "Precursor_type", + "precursor_mz": "PrecursorMZ", + "retention_time": "RETENTIONTIME", + "ccs": "CCS", + } + + # Build metabolites + metabolites = [] + invalid_rows = [] + for idx, row in df.iterrows(): + smiles = row.get(args.smiles_col) + if smiles is None or (isinstance(smiles, float) and np.isnan(smiles)): + invalid_rows.append(idx) + continue + mol = _safe_metabolite(smiles) + if mol is None: + invalid_rows.append(idx) + continue + mol.create_molecular_structure_graph() + mol.compute_graph_attributes(node_encoder, bond_encoder) + + if args.group_id_col in df.columns: + try: + mol.set_id(int(row[args.group_id_col])) + except Exception: + pass + + summary = None + if args.summary_col in df.columns: + summary = row.get(args.summary_col) + if summary is None: + summary = _build_summary_from_columns(row, metadata_key_map) + + try: + mol.add_metadata(summary, covariate_encoder, rt_encoder) + except Exception: + invalid_rows.append(idx) + continue + + if args.loss_weight_col in df.columns: + try: + mol.set_loss_weight(float(row[args.loss_weight_col])) + except Exception: + mol.set_loss_weight(1.0) + else: + mol.set_loss_weight(1.0) + + metabolites.append(mol) + df.at[idx, "Metabolite"] = mol + + if invalid_rows: + df = df.drop(index=invalid_rows) + print(f"Dropped {len(invalid_rows)} invalid rows.") + + # Fragmentation trees + if args.use_frag_index: + mindex = MetaboliteIndex() + mindex.index_metabolites(df["Metabolite"]) + mindex.create_fragmentation_trees(depth=args.fragmentation_depth) + mindex.add_fragmentation_trees_to_metabolite_list( + df["Metabolite"], graph_mismatch_policy=args.graph_mismatch_policy + ) + else: + df["Metabolite"].apply(lambda x: x.fragment_MOL(depth=args.fragmentation_depth)) + + # Match peaks to fragments + ppm_default = args.ppm if args.ppm is not None else DEFAULT_PPM + match_invalid = [] + for idx, row in df.iterrows(): + peaks = row.get(args.peaks_col) + if not 
isinstance(peaks, dict): + match_invalid.append(idx) + continue + mz = peaks.get("mz") + intensity = peaks.get("intensity") + if mz is None or intensity is None or len(mz) == 0: + match_invalid.append(idx) + continue + tol = ppm_default + if args.ppm_col in df.columns: + try: + val = float(row[args.ppm_col]) + if not np.isnan(val): + tol = val + except Exception: + pass + try: + row["Metabolite"].match_fragments_to_peaks(mz, intensity, tolerance=tol) + except Exception: + match_invalid.append(idx) + + if match_invalid: + df = df.drop(index=match_invalid) + print(f"Dropped {len(match_invalid)} rows with invalid peaks.") + + df["num_peak_matches"] = df["Metabolite"].apply( + lambda x: x.match_stats["num_peak_matches"] + ) + if args.min_peak_matches > 0: + before = len(df) + df = df[df["num_peak_matches"] >= args.min_peak_matches] + print( + f"Filtered {before - len(df)} rows with < {args.min_peak_matches} peak matches." + ) + + # Train/val split + train_keys = [] + val_keys = [] + if args.datasplit_col in df.columns: + df_train = df[df[args.datasplit_col].isin([args.train_label, args.val_label])] + if args.group_id_col in df.columns: + train_keys = ( + df[df[args.datasplit_col] == args.train_label][args.group_id_col] + .dropna() + .unique() + .tolist() + ) + val_keys = ( + df[df[args.datasplit_col] == args.val_label][args.group_id_col] + .dropna() + .unique() + .tolist() + ) + else: + df_train = df + + # Geometric data + geo_data = [] + for _, row in df_train.iterrows(): + data = row["Metabolite"].as_geometric_data().to(dev) + if args.group_id_col in df_train.columns: + try: + data.group_id = int(row[args.group_id_col]) + except Exception: + pass + geo_data.append(data) + print(f"Prepared training/validation with {len(geo_data)} data points") + + # Model params + base_params = _load_model_params(args.model_params) + model_params = dict(base_params) + model_params.update( + { + "node_feature_layout": node_encoder.feature_numbers, + "edge_feature_layout": bond_encoder.feature_numbers, + "static_feature_dimension": geo_data[0]["static_edge_features"].shape[1], + "static_rt_feature_dimension": geo_data[0]["static_rt_features"].shape[1], + "output_dimension": len(DEFAULT_MODES) * 2, + "atom_features": node_encoder.feature_list, + "setup_features": covariate_encoder.feature_list, + "setup_features_categorical_set": covariate_encoder.categorical_sets, + "rt_features": rt_encoder.feature_list, + "prepare_additional_layers": args.with_rt or args.with_ccs, + "rt_supported": args.with_rt, + "ccs_supported": args.with_ccs, + } + ) + + # Initialize or resume model + if args.resume: + state_path = args.resume.replace(".pt", "_state.pt") + params_path = args.resume.replace(".pt", "_params.json") + if os.path.exists(state_path) and os.path.exists(params_path): + model = FioraModel.load_from_state_dict(args.resume).to(dev) + else: + model = FioraModel.load(args.resume).to(dev) + else: + model = FioraModel(model_params).to(dev) + + if (args.with_rt or args.with_ccs) and not model.model_params.get( + "prepare_additional_layers", False + ): + raise RuntimeError( + "Model does not include RT/CCS heads but --with-rt/--with-ccs was set." 
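        # e.g. _choose_loss("graphwise_kl") below pairs the KL-divergence
        # training loss with its matching metric dict, mirroring the choices
        # exposed via --loss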
+ ) + + loss_fn, metric_dict = _choose_loss(args.loss) + + split_by_group = args.split_by_group and args.group_id_col in df_train.columns + only_training = len(val_keys) == 0 and not args.use_validation_mask + + trainer = SpectralTrainer( + geo_data, + y_tag=args.y_label, + problem_type="regression", + train_val_split=args.train_val_split, + split_by_group=split_by_group, + only_training=only_training, + train_keys=train_keys, + val_keys=val_keys, + metric_dict=metric_dict, + seed=args.seed, + device=dev, + num_workers=args.num_workers, + ) + + optimizer = torch.optim.Adam( + model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay + ) + + scheduler = None + if args.scheduler == "plateau": + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + patience=args.scheduler_patience, + factor=args.scheduler_factor, + mode="min", + verbose=True, + ) + + output_path = args.output + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + + checkpoints = trainer.train( + model, + optimizer, + loss_fn, + scheduler=scheduler, + batch_size=args.batch_size, + epochs=args.epochs, + val_every_n_epochs=args.val_every, + use_validation_mask=args.use_validation_mask, + with_RT=args.with_rt, + with_CCS=args.with_ccs, + rt_metric=args.rt_metric, + mask_name=args.validation_mask_name, + save_path=output_path, + tag="train", + ) + + print(f"Finished training. Best checkpoint: {checkpoints['file']}") + + +if __name__ == "__main__": + main() diff --git a/fiora/visualization/define_colors.py b/fiora/visualization/define_colors.py index 3969296..a0e0f46 100644 --- a/fiora/visualization/define_colors.py +++ b/fiora/visualization/define_colors.py @@ -16,7 +16,9 @@ def mix_colors(c1, c2, ratio=1.0): col_mistle_bright = (0.5 * col_mistle[0], 0.9 * col_mistle[1], 0.9 * col_mistle[2]) col_decoy = sns.color_palette("Set3")[1] -col_spectrast = sns.color_palette(palette="Set3")[4] # (1,1,1) #sns.color_palette(palette="Set3")[4] +col_spectrast = sns.color_palette(palette="Set3")[ + 4 +] # (1,1,1) #sns.color_palette(palette="Set3")[4] # col_spectrast = (1,1,1) #sns.color_palette(palette="Set3")[4] col_xtandem = sns.color_palette(palette="Set3")[3] col_msf = sns.color_palette(palette="Set3")[3] @@ -29,7 +31,13 @@ def mix_colors(c1, c2, ratio=1.0): palette = sns.color_palette("colorblind") color_palette = palette -C = {"green": palette[2], "orange": palette[1], "blue": palette[0], "red": palette[3], "yellow": palette[8]} +C = { + "green": palette[2], + "orange": palette[1], + "blue": palette[0], + "red": palette[3], + "yellow": palette[8], +} C["g"] = C["green"] C["o"] = C["orange"] C["b"] = C["blue"] @@ -37,7 +45,9 @@ def mix_colors(c1, c2, ratio=1.0): C["lightgreen"] = [1.25 * x for x in C["green"]] C["darkgreen"] = [0.75 * x for x in C["green"]] C["ivorygreen"] = mix_colors(C["green"], matplotlib.colors.to_rgb("ivory"), ratio=0.5) -C["chocolategreen"] = mix_colors(C["green"], matplotlib.colors.to_rgb("chocolate"), ratio=1.5) +C["chocolategreen"] = mix_colors( + C["green"], matplotlib.colors.to_rgb("chocolate"), ratio=1.5 +) PRINT_COL = { @@ -46,7 +56,7 @@ def mix_colors(c1, c2, ratio=1.0): "green": "\033[92m", "yellow": "\033[93m", "red": "\033[91m", - "end": "\033[00m" + "end": "\033[00m", } ELEMENT_COLORS = { @@ -68,10 +78,10 @@ def mix_colors(c1, c2, ratio=1.0): lightpink = (81, 55, 52) lightpink_hex = "#cf8c85" newpink = (255, 64, 85) # (light blue + Red 255) -newpink_hex = "#ffa3d6" +newpink_hex = "#ffa3d6" newnewpink = (242, 163, 214) # (light blue + Red 255) -newnewpink_hex = 
"#F2A3D6" +newnewpink_hex = "#F2A3D6" wippinkbutbestsofar = (221, 140, 150) wippinkbutbestsofar_hex = "#DD8C96" @@ -89,11 +99,19 @@ def mix_colors(c1, c2, ratio=1.0): lightgreen_hex = "#ACF39D" wine_hex = "#773344" -bluepink = sns.color_palette([lightblue_hex, lightpink_hex, black_hex, lightgreen_hex, wine_hex], as_cmap=True) -bluepink_grad = sns.diverging_palette(17.7, 245.8, s=75, l=50, sep=1, n=6, center='light', as_cmap=True) -bluepink_grad8 = sns.diverging_palette(17.7, 245.8, s=75, l=50, sep=1, n=8, center='light', as_cmap=False) +bluepink = sns.color_palette( + [lightblue_hex, lightpink_hex, black_hex, lightgreen_hex, wine_hex], as_cmap=True +) +bluepink_grad = sns.diverging_palette( + 17.7, 245.8, s=75, l=50, sep=1, n=6, center="light", as_cmap=True +) +bluepink_grad8 = sns.diverging_palette( + 17.7, 245.8, s=75, l=50, sep=1, n=8, center="light", as_cmap=False +) + +tri_palette = ["gray", bluepink[0], bluepink[1]] + -tri_palette=["gray", bluepink[0], bluepink[1]] def magma(steps): return sns.color_palette("magma_r", steps) @@ -107,24 +125,52 @@ def define_figure_style(style: str, palette_steps=8): # Define figure styles if "magma-white": color_palette = sns.color_palette("magma_r", palette_steps) - sns.set_theme(style="whitegrid", - rc={'axes.edgecolor': 'black', 'ytick.left': True, 'xtick.bottom': True, 'xtick.color': 'black', - "axes.spines.bottom": True, "axes.spines.right": True, "axes.spines.top": True, - "axes.spines.left": True}) + sns.set_theme( + style="whitegrid", + rc={ + "axes.edgecolor": "black", + "ytick.left": True, + "xtick.bottom": True, + "xtick.color": "black", + "axes.spines.bottom": True, + "axes.spines.right": True, + "axes.spines.top": True, + "axes.spines.left": True, + }, + ) return color_palette + def set_theme(): - sns.set_theme(style="darkgrid", - rc={'axes.edgecolor': 'black', 'ytick.left': True, 'xtick.bottom': True, 'xtick.color': 'black', - "axes.spines.bottom": True, "axes.spines.right": False, "axes.spines.top": False, - "axes.spines.left": True}) + sns.set_theme( + style="darkgrid", + rc={ + "axes.edgecolor": "black", + "ytick.left": True, + "xtick.bottom": True, + "xtick.color": "black", + "axes.spines.bottom": True, + "axes.spines.right": False, + "axes.spines.top": False, + "axes.spines.left": True, + }, + ) def set_light_theme(): - sns.set_theme(style="whitegrid", - rc={'axes.edgecolor': 'black', 'ytick.left': True, 'xtick.bottom': True, 'xtick.color': 'black', - "axes.spines.bottom": True, "axes.spines.right": True, "axes.spines.top": True, - "axes.spines.left": True}) + sns.set_theme( + style="whitegrid", + rc={ + "axes.edgecolor": "black", + "ytick.left": True, + "xtick.bottom": True, + "xtick.color": "black", + "axes.spines.bottom": True, + "axes.spines.right": True, + "axes.spines.top": True, + "axes.spines.left": True, + }, + ) def reset_matplotlib(): @@ -132,8 +178,18 @@ def reset_matplotlib(): def set_all_font_sizes(size): - zs = ['font.size', 'axes.labelsize', 'axes.titlesize', 'legend.fontsize', "xtick.labelsize", "xtick.major.size", - "xtick.minor.size", "ytick.labelsize", "ytick.major.size", "ytick.minor.size"] + zs = [ + "font.size", + "axes.labelsize", + "axes.titlesize", + "legend.fontsize", + "xtick.labelsize", + "xtick.major.size", + "xtick.minor.size", + "ytick.labelsize", + "ytick.major.size", + "ytick.minor.size", + ] for z in zs: plt.rcParams[z] = size @@ -145,9 +201,7 @@ def set_plt_params_to_default(): def adjust_box_widths_for_all_axes(fig, fac): for ax in fig.axes: - for c in ax.get_children(): - if isinstance(c, 
PathPatch): # getting current width of box: p = c.get_path() @@ -155,12 +209,12 @@ def adjust_box_widths_for_all_axes(fig, fac): verts_sub = verts[:-1] xmin = np.min(verts_sub[:, 0]) xmax = np.max(verts_sub[:, 0]) - xmid = 0.5*(xmin+xmax) - xhalf = 0.5*(xmax - xmin) + xmid = 0.5 * (xmin + xmax) + xhalf = 0.5 * (xmax - xmin) # setting new width of box - xmin_new = xmid-fac*xhalf - xmax_new = xmid+fac*xhalf + xmin_new = xmid - fac * xhalf + xmax_new = xmid + fac * xhalf verts_sub[verts_sub[:, 0] == xmin, 0] = xmin_new verts_sub[verts_sub[:, 0] == xmax, 0] = xmax_new @@ -168,30 +222,30 @@ def adjust_box_widths_for_all_axes(fig, fac): for l in ax.lines: if np.all(l.get_xdata() == [xmin, xmax]): l.set_xdata([xmin_new, xmax_new]) - -def adjust_box_widths(ax, fac): - for c in ax.get_children(): - - if isinstance(c, PathPatch): - # getting current width of box: - p = c.get_path() - verts = p.vertices - verts_sub = verts[:-1] - xmin = np.min(verts_sub[:, 0]) - xmax = np.max(verts_sub[:, 0]) - xmid = 0.5*(xmin+xmax) - xhalf = 0.5*(xmax - xmin) - # setting new width of box - xmin_new = xmid-fac*xhalf - xmax_new = xmid+fac*xhalf - verts_sub[verts_sub[:, 0] == xmin, 0] = xmin_new - verts_sub[verts_sub[:, 0] == xmax, 0] = xmax_new - # setting new width of median line - for l in ax.lines: - if np.all(l.get_xdata() == [xmin, xmax]): - l.set_xdata([xmin_new, xmax_new]) +def adjust_box_widths(ax, fac): + for c in ax.get_children(): + if isinstance(c, PathPatch): + # getting current width of box: + p = c.get_path() + verts = p.vertices + verts_sub = verts[:-1] + xmin = np.min(verts_sub[:, 0]) + xmax = np.max(verts_sub[:, 0]) + xmid = 0.5 * (xmin + xmax) + xhalf = 0.5 * (xmax - xmin) + + # setting new width of box + xmin_new = xmid - fac * xhalf + xmax_new = xmid + fac * xhalf + verts_sub[verts_sub[:, 0] == xmin, 0] = xmin_new + verts_sub[verts_sub[:, 0] == xmax, 0] = xmax_new + + # setting new width of median line + for l in ax.lines: + if np.all(l.get_xdata() == [xmin, xmax]): + l.set_xdata([xmin_new, xmax_new]) def adjust_bar_widths(ax, fac): @@ -204,5 +258,3 @@ def adjust_bar_widths(ax, fac): new_width = fac * bar_width bar.set_width(new_width) bar.set_x(bar_center - new_width / 2) - - diff --git a/fiora/visualization/inspect_mgf_file.py b/fiora/visualization/inspect_mgf_file.py index 7cbf038..26ca4c9 100644 --- a/fiora/visualization/inspect_mgf_file.py +++ b/fiora/visualization/inspect_mgf_file.py @@ -5,10 +5,16 @@ import pandas as pd parser = argparse.ArgumentParser() -parser.add_argument("-v", "--verbose", help="increase output verbosity", action="store_true") -parser.add_argument("-i", "--infile", - help="path/file.mgf to the search file, which will be inspected", - type=str, required=True) +parser.add_argument( + "-v", "--verbose", help="increase output verbosity", action="store_true" +) +parser.add_argument( + "-i", + "--infile", + help="path/file.mgf to the search file, which will be inspected", + type=str, + required=True, +) args = parser.parse_args() @@ -22,17 +28,19 @@ # CHARGE -df['CHARGE'] = df['CHARGE'].apply(str) -df['CHARGE'] = pd.Categorical(df['CHARGE'], sorted(df.CHARGE.unique())) +df["CHARGE"] = df["CHARGE"].apply(str) +df["CHARGE"] = pd.Categorical(df["CHARGE"], sorted(df.CHARGE.unique())) -ax = sns.countplot(data=df, x="CHARGE", palette=color_palette, edgecolor="black") #order=df['CHARGE'].value_counts().iloc[:10].index) +ax = sns.countplot( + data=df, x="CHARGE", palette=color_palette, edgecolor="black" +) # order=df['CHARGE'].value_counts().iloc[:10].index) plt.title("MS/MS charge 
distribution") plt.show() # Precursor m/z -df['precursor_mz'] = df['PEPMASS'].apply(lambda x: float(x.split(' ')[0])) +df["precursor_mz"] = df["PEPMASS"].apply(lambda x: float(x.split(" ")[0])) sns.boxplot(data=df, y="precursor_mz", x="CHARGE", palette=color_palette) plt.title("MS/MS precursor mz range over charge") @@ -40,8 +48,8 @@ # Num of peaks -df['num_peaks'] = df['peaks'].apply(lambda p: len(p['mz'])) +df["num_peaks"] = df["peaks"].apply(lambda p: len(p["mz"])) sns.boxplot(data=df, y="num_peaks", x="CHARGE", palette=color_palette) plt.title("MS/MS number of peaks per spectrum over charge") -plt.show() \ No newline at end of file +plt.show() diff --git a/fiora/visualization/plot_spectrum.py b/fiora/visualization/plot_spectrum.py index c524eb9..424f254 100644 --- a/fiora/visualization/plot_spectrum.py +++ b/fiora/visualization/plot_spectrum.py @@ -8,21 +8,26 @@ import matplotlib.pyplot as plt parser = argparse.ArgumentParser() -parser.add_argument("-f", "--file1", help="file where spectrum is contained (.mgf or .msp)", type=str, - default="/home/ynowatzk/data/9MM/mgf/9MM_FASP.mgf") -parser.add_argument("-n", "--name1", help="exact name of spectrum", type=str, - required=True) +parser.add_argument( + "-f", + "--file1", + help="file where spectrum is contained (.mgf or .msp)", + type=str, + default="/home/ynowatzk/data/9MM/mgf/9MM_FASP.mgf", +) +parser.add_argument( + "-n", "--name1", help="exact name of spectrum", type=str, required=True +) -parser.add_argument("-f2", "--file2", help="file where lower spectrum is found", - type=str) +parser.add_argument( + "-f2", "--file2", help="file where lower spectrum is found", type=str +) -parser.add_argument("-n2", "--name2", help="exact name of lower spectrum", - type=str) -parser.add_argument("-o", "--out", help="output file", - type=str) -#parser.add_argument("-a", "--annotate", help="perform spectrum annotation", action="store_true", default=False) -#parser.add_argument("-p", "--peptide", help="peptide", type=str, default="None") -#parser.add_argument("-c", "--charge", help="charge", type=int, default=0) +parser.add_argument("-n2", "--name2", help="exact name of lower spectrum", type=str) +parser.add_argument("-o", "--out", help="output file", type=str) +# parser.add_argument("-a", "--annotate", help="perform spectrum annotation", action="store_true", default=False) +# parser.add_argument("-p", "--peptide", help="peptide", type=str, default="None") +# parser.add_argument("-c", "--charge", help="charge", type=int, default=0) parser.add_argument("--fontsize", help="font size of the text", type=int) args = parser.parse_args() @@ -59,6 +64,13 @@ def read_spectrum_from_file(file, name): if args.file2 and args.name2: s2 = read_spectrum_from_file(args.file2, args.name2) - sv.plot_spectrum(s1, s2, title=args.name1 + " matched by " + args.name2.split("/")[0], out=args.out) #,annotate=args.annotate, peptide=args.peptide, charge=args.charge, font_size=args.fontsize) + sv.plot_spectrum( + s1, + s2, + title=args.name1 + " matched by " + args.name2.split("/")[0], + out=args.out, + ) # ,annotate=args.annotate, peptide=args.peptide, charge=args.charge, font_size=args.fontsize) else: - sv.plot_spectrum(s1, title=args.name1, out=args.out, show=True)#, font_size=args.fontsize) \ No newline at end of file + sv.plot_spectrum( + s1, title=args.name1, out=args.out, show=True + ) # , font_size=args.fontsize) diff --git a/fiora/visualization/spectrum_visualizer.py b/fiora/visualization/spectrum_visualizer.py index 0946fa3..396f789 100644 --- 
a/fiora/visualization/spectrum_visualizer.py +++ b/fiora/visualization/spectrum_visualizer.py @@ -11,7 +11,9 @@ # From spectrum utils issue https://github.com/bittremieux/spectrum_utils/issues/56 # Overwrite get_theoretical_fragments -def get_theoretical_fragments(proteoform, ion_types=None, max_ion_charge=None, neutral_losses=None): +def get_theoretical_fragments( + proteoform, ion_types=None, max_ion_charge=None, neutral_losses=None +): fragments_masses = [] for mod in proteoform.modifications: fragment = fa.FragmentAnnotation(ion_type="w", charge=1) @@ -19,26 +21,32 @@ def get_theoretical_fragments(proteoform, ion_types=None, max_ion_charge=None, n fragments_masses.append((fragment, mass)) return fragments_masses + def set_custom_annotation(): # Use the custom function to annotate the fragments fa.get_theoretical_fragments = get_theoretical_fragments - fa._supported_ions += "w" + fa._supported_ions += "w" # Set peak color for custom ion sup.colors["w"] = lightblue_hex + def set_default_peak_color(color): sup.colors[None] = color -def annotate_and_plot(spectrum, mz_fragments, with_grid: bool=False, ppm_tolerance: int=100, ax=None): - + +def annotate_and_plot( + spectrum, mz_fragments, with_grid: bool = False, ppm_tolerance: int = 100, ax=None +): + set_custom_annotation() # Instantiate Spectrum and annotate with proforma string format (e.g. X[+9.99] ) - spectrum = sus.MsmsSpectrum("None", 0, 1, spectrum['peaks']['mz'], spectrum['peaks']['intensity']) + spectrum = sus.MsmsSpectrum( + "None", 0, 1, spectrum["peaks"]["mz"], spectrum["peaks"]["intensity"] + ) x_string = "".join([f"X[+{mz}]" for mz in sorted(mz_fragments)]) spectrum.annotate_proforma(x_string, ppm_tolerance, "ppm") - - + # Find ax and plot if not ax: ax = plt.gca() @@ -50,44 +58,78 @@ def annotate_and_plot(spectrum, mz_fragments, with_grid: bool=False, ppm_toleran return ax -def plot_spectrum(spectrum: Dict, second_spectrum: Dict|None=None, highlight_matches: bool=False, mz_matches: List[int]=[], facet_plot=False, ppm_tolerance: int=100, charge=0, title=None, out=None, with_grid=False, ax=None, show=False, color=None): - top_spectrum = sus.MsmsSpectrum("None", 0, charge, spectrum['peaks']['mz'], spectrum['peaks']['intensity']) + +def plot_spectrum( + spectrum: Dict, + second_spectrum: Dict | None = None, + highlight_matches: bool = False, + mz_matches: List[int] = [], + facet_plot=False, + ppm_tolerance: int = 100, + charge=0, + title=None, + out=None, + with_grid=False, + ax=None, + show=False, + color=None, +): + top_spectrum = sus.MsmsSpectrum( + "None", 0, charge, spectrum["peaks"]["mz"], spectrum["peaks"]["intensity"] + ) if color: set_default_peak_color(color) if not ax: fig, ax = plt.subplots(figsize=(12, 6)) # spectrum.set_mz_range(min_mz=0, max_mz=2000) if second_spectrum is not None: - bottom_spectrum = sus.MsmsSpectrum("None", 0, charge, second_spectrum['peaks']['mz'], second_spectrum['peaks']['intensity']) + bottom_spectrum = sus.MsmsSpectrum( + "None", + 0, + charge, + second_spectrum["peaks"]["mz"], + second_spectrum["peaks"]["intensity"], + ) if highlight_matches: set_custom_annotation() - x_string = "".join([f"X[+{mz}]" for mz in sorted(second_spectrum['peaks']['mz'])]) + x_string = "".join( + [f"X[+{mz}]" for mz in sorted(second_spectrum["peaks"]["mz"])] + ) top_spectrum.annotate_proforma(x_string, ppm_tolerance, "ppm") - x_string = "".join([f"X[+{mz}]" for mz in sorted(spectrum['peaks']['mz'])]) + x_string = "".join([f"X[+{mz}]" for mz in sorted(spectrum["peaks"]["mz"])]) 
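The "X[+mass]" strings built here deserve a standalone illustration: every fragment m/z is encoded as a modification on a dummy residue, so that spectrum_utils (through the overridden get_theoretical_fragments above) annotates any peak matching one of these masses within the ppm tolerance. A minimal sketch with made-up masses:

mz_fragments = [77.0386, 105.0335, 121.0284]  # example values, not from any dataset
x_string = "".join(f"X[+{mz}]" for mz in sorted(mz_fragments))
print(x_string)  # X[+77.0386]X[+105.0335]X[+121.0284]
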
bottom_spectrum.annotate_proforma(x_string, ppm_tolerance, "ppm") - + if facet_plot: - sup.facet(spec_top=top_spectrum, spec_mass_errors=top_spectrum, spec_bottom=bottom_spectrum, mass_errors_kws={"plot_unknown": False}) - else: # mirror plot - sup.mirror(spec_top=top_spectrum, spec_bottom=bottom_spectrum, ax=ax, spectrum_kws={"grid": with_grid}) - + sup.facet( + spec_top=top_spectrum, + spec_mass_errors=top_spectrum, + spec_bottom=bottom_spectrum, + mass_errors_kws={"plot_unknown": False}, + ) + else: # mirror plot + sup.mirror( + spec_top=top_spectrum, + spec_bottom=bottom_spectrum, + ax=ax, + spectrum_kws={"grid": with_grid}, + ) + if with_grid: ax.set_ylim(-1.075, 1.075) else: sns.despine(ax=ax) - if second_spectrum is not None: - ax.spines['bottom'].set_position(('outward', 10)) + if second_spectrum is not None: + ax.spines["bottom"].set_position(("outward", 10)) # Single spectrum else: - if highlight_matches and mz_matches: set_custom_annotation() x_string = "".join([f"X[+{mz}]" for mz in sorted(mz_matches)]) top_spectrum.annotate_proforma(x_string, ppm_tolerance, "ppm") - + sup.spectrum(top_spectrum, grid=with_grid, ax=ax) if with_grid: ax.set_ylim(0, 1.1) @@ -96,24 +138,37 @@ def plot_spectrum(spectrum: Dict, second_spectrum: Dict|None=None, highlight_mat plt.title(title) if out is not None: - plt.savefig(out) + plt.savefig(out) else: if show: plt.show() else: return ax -def plot_vector_spectrum(vec1, vec2, ax=None, title=None, y_label="probability", names= None): - v1 = pd.DataFrame({"range": names if names else range(len(vec1)), "prob": vec1, "group": "prob"}) - v2 = pd.DataFrame({"range": names if names else range(len(vec2)), "prob": vec2, "group": "pred"}) + +def plot_vector_spectrum( + vec1, vec2, ax=None, title=None, y_label="probability", names=None +): + v1 = pd.DataFrame( + {"range": names if names else range(len(vec1)), "prob": vec1, "group": "prob"} + ) + v2 = pd.DataFrame( + {"range": names if names else range(len(vec2)), "prob": vec2, "group": "pred"} + ) V = pd.concat([v1, v2]) if not ax: fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4)) - sns.barplot(ax=ax, data=V, y="prob", x="range", edgecolor="black", hue="group", linewidth=1.5) + sns.barplot( + ax=ax, + data=V, + y="prob", + x="range", + edgecolor="black", + hue="group", + linewidth=1.5, + ) ax.set_xlabel("") ax.set_ylabel(y_label) ax.set_title(title) return ax - - diff --git a/lib_loader/casmi16_loader.ipynb b/lib_loader/casmi16_loader.ipynb index 2ffe10f..c31ee08 100644 --- a/lib_loader/casmi16_loader.ipynb +++ b/lib_loader/casmi16_loader.ipynb @@ -31,12 +31,14 @@ ], "source": [ "import sys\n", - "print(f'Working with Python {sys.version}')\n", + "\n", + "print(f\"Working with Python {sys.version}\")\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import importlib\n", - "#import swifter\n", + "\n", + "# import swifter\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import collections\n", @@ -52,17 +54,20 @@ "\n", "# Deep Learning\n", "import sklearn\n", - "#import spektral\n", + "\n", + "# import spektral\n", "from sklearn.model_selection import train_test_split\n", + "\n", "# Keras\n", "from sklearn.model_selection import train_test_split\n", - "#import stellargraph as sg\n", - "from rdkit import RDLogger\n", "\n", + "# import stellargraph as sg\n", + "from rdkit import RDLogger\n", "\n", "\n", "# Load Modules\n", "from os.path import expanduser\n", + "\n", "home = expanduser(\"~\")\n", "import fiora.IO.mspReader as mspReader\n", "import fiora.IO.mgfReader as mgfReader\n", @@ 
-70,10 +75,10 @@ "import fiora.IO.molReader as molReader\n", "\n", "\n", - "RDLogger.DisableLog('rdApp.*')\n", + "RDLogger.DisableLog(\"rdApp.*\")\n", "\n", "\n", - "caffeine_smiles = 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C'\n", + "caffeine_smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\"\n", "caffeine_mol = Chem.MolFromSmiles(caffeine_smiles)\n", "\n", "caffeine_mol" @@ -113,7 +118,9 @@ } ], "source": [ - "library_directory = f\"{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Challenge_negative_mgf\"\n", + "library_directory = (\n", + " f\"{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Challenge_negative_mgf\"\n", + ")\n", "!ls $library_directory" ] }, @@ -133,21 +140,23 @@ "outputs": [], "source": [ "df = []\n", - "library_directory = f\"{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Challenge_negative_mgf\"\n", + "library_directory = (\n", + " f\"{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Challenge_negative_mgf\"\n", + ")\n", "\n", "for file in os.listdir(library_directory):\n", " if file.endswith(\".mgf\"):\n", " data = mgfReader.read(os.path.join(library_directory, file), sep=\"\\t\")[0]\n", - " data['FILE'] = file\n", - " data['Precursor_type'] = \"[M-H]-\"\n", + " data[\"FILE\"] = file\n", + " data[\"Precursor_type\"] = \"[M-H]-\"\n", " df += [data]\n", "\n", "library_directory = library_directory.replace(\"negative\", \"positive\")\n", "for file in os.listdir(library_directory):\n", " if file.endswith(\".mgf\"):\n", " data = mgfReader.read(os.path.join(library_directory, file), sep=\"\\t\")[0]\n", - " data['FILE'] = file\n", - " data['Precursor_type'] = \"[M+H]+\"\n", + " data[\"FILE\"] = file\n", + " data[\"Precursor_type\"] = \"[M+H]+\"\n", " df += [data]\n", "\n", "df = pd.DataFrame(df)" @@ -240,7 +249,9 @@ } ], "source": [ - "solution = pd.read_csv(os.path.join(library_directory, \"solutions_casmi2016_cat2and3.csv\"))\n", + "solution = pd.read_csv(\n", + " os.path.join(library_directory, \"solutions_casmi2016_cat2and3.csv\")\n", + ")\n", "# solution = solution[solution[\"ION_MODE\"] == \" POSITIVE\"]\n", "# solution.reset_index(inplace=True, drop=True)\n", "df = pd.concat([df, solution], axis=1)\n", @@ -265,7 +276,7 @@ ], "source": [ "# Check that challenges and solutions are aligned correctly\n", - "df.apply(lambda x: x[\"FILE\"].split('.')[0] == x[\"ChallengeName\"], axis=1).all()" + "df.apply(lambda x: x[\"FILE\"].split(\".\")[0] == x[\"ChallengeName\"], axis=1).all()" ] }, { @@ -287,17 +298,18 @@ } ], "source": [ - "\n", - "library_candidates_directory = f\"{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Challenge_Candidates\"\n", + "library_candidates_directory = (\n", + " f\"{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Challenge_Candidates\"\n", + ")\n", "candidates = []\n", "for file in os.listdir(library_candidates_directory):\n", " if file.endswith(\".csv\"):\n", " data = pd.read_csv(os.path.join(library_candidates_directory, file), sep=\",\")\n", " SMILES = list(data[\"SMILES\"])\n", - " d = {'cFILE': file, 'Candidates': SMILES}\n", + " d = {\"cFILE\": file, \"Candidates\": SMILES}\n", " candidates += [d]\n", "\n", - "candidates = pd.DataFrame(candidates) #.loc[81:].reset_index(drop=True)\n", + "candidates = pd.DataFrame(candidates) # .loc[81:].reset_index(drop=True)\n", "print(candidates.head())" ] }, @@ -319,7 +331,7 @@ ], "source": [ "df = pd.concat([df, candidates], axis=1)\n", - "df.apply(lambda x: x[\"cFILE\"].split('.')[0] == x[\"ChallengeName\"], axis=1).all()" + "df.apply(lambda x: x[\"cFILE\"].split(\".\")[0] == 
x[\"ChallengeName\"], axis=1).all()" ] }, { @@ -361,15 +373,14 @@ "save_df = False\n", "name = \"casmi16_challenges_combined.csv\"\n", "\n", - "library_directory = '/'.join(library_directory.split('/')[:-1])\n", + "library_directory = \"/\".join(library_directory.split(\"/\")[:-1])\n", "print(library_directory)\n", "if save_df:\n", - " file = os.path.join(library_directory, name)\n", - " print(\"saving to \", file)\n", - " df.to_csv(file)\n", - " \n", - " #df_nist[key_columns].to_csv(library_directory + name + \"all\" + \"_\" + date + \".csv\")\n", - "\n" + " file = os.path.join(library_directory, name)\n", + " print(\"saving to \", file)\n", + " df.to_csv(file)\n", + "\n", + " # df_nist[key_columns].to_csv(library_directory + name + \"all\" + \"_\" + date + \".csv\")" ] }, { @@ -542,11 +553,15 @@ "\n", "if save_df:\n", " file = os.path.join(cfm_directory, name)\n", - " df_cfm[df_cfm[\"Precursor_type\"] == \"[M-H]-\"][[\"ChallengeName\", \"SMILES\"]].to_csv(file, index=False, header=False, sep=\" \")\n", - " \n", + " df_cfm[df_cfm[\"Precursor_type\"] == \"[M-H]-\"][[\"ChallengeName\", \"SMILES\"]].to_csv(\n", + " file, index=False, header=False, sep=\" \"\n", + " )\n", + "\n", " name = name.replace(\"negative\", \"positive\")\n", " file = os.path.join(cfm_directory, name)\n", - " df_cfm[df_cfm[\"Precursor_type\"] == \"[M+H]+\"][[\"ChallengeName\", \"SMILES\"]].to_csv(file, index=False, header=False, sep=\" \")" + " df_cfm[df_cfm[\"Precursor_type\"] == \"[M+H]+\"][[\"ChallengeName\", \"SMILES\"]].to_csv(\n", + " file, index=False, header=False, sep=\" \"\n", + " )" ] }, { @@ -806,7 +821,15 @@ "df[\"MOL\"] = df[\"SMILES\"].apply(Chem.MolFromSmiles)\n", "df[\"formula\"] = df[\"MOL\"].apply(rdMolDescriptors.CalcMolFormula)\n", "df[\"dataset\"] = \"CASMI16\"\n", - "df = df.rename(columns={\"FILE\": \"spec\", \"ChallengeName\": \"name\", \"Precursor_type\": \"ionization\", \"SMILES\": \"smiles\", \"INCHIKEY\": \"inchikey\"})" + "df = df.rename(\n", + " columns={\n", + " \"FILE\": \"spec\",\n", + " \"ChallengeName\": \"name\",\n", + " \"Precursor_type\": \"ionization\",\n", + " \"SMILES\": \"smiles\",\n", + " \"INCHIKEY\": \"inchikey\",\n", + " }\n", + ")" ] }, { @@ -818,9 +841,13 @@ "save_df = False\n", "if save_df:\n", " output_file = f\"{home}/data/metabolites/ms-pred/casmi16_labels.tsv\"\n", - " df[[\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\",\t\"smiles\", \"inchikey\"]].to_csv(output_file, index=False, sep=\"\\t\")\n", + " df[\n", + " [\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\", \"smiles\", \"inchikey\"]\n", + " ].to_csv(output_file, index=False, sep=\"\\t\")\n", " output_file = f\"{home}/data/metabolites/ms-pred/casmi16_positive_labels.tsv\"\n", - " df[df[\"ionization\"] == \"[M+H]+\"][[\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\",\t\"smiles\", \"inchikey\"]].to_csv(output_file, index=False, sep=\"\\t\")" + " df[df[\"ionization\"] == \"[M+H]+\"][\n", + " [\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\", \"smiles\", \"inchikey\"]\n", + " ].to_csv(output_file, index=False, sep=\"\\t\")" ] }, { @@ -856,20 +883,27 @@ "\n", "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", - "setup_encoder = SetupFeatureEncoder(feature_list=[\"collision_energy\", \"molecular_weight\"])\n", + "setup_encoder = SetupFeatureEncoder(\n", + " feature_list=[\"collision_energy\", \"molecular_weight\"]\n", + ")\n", 
"df[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", "\n", "\n", - "df[\"CE\"] = 35.0 # 20/35/50\n", + "df[\"CE\"] = 35.0 # 20/35/50\n", "df[\"Instrument_type\"] = \"HCD\"\n", "df[\"Ionization\"] = \"ESI-MS/MS\"\n", - "metadata_key_map = {\"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"ionization\": \"Ionization\",\n", - " #\"precursor_mz\": \"PrecursorMZ\"\n", - " }\n", - "df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", - "df.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder=None), axis=1)" + "metadata_key_map = {\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"ionization\": \"Ionization\",\n", + " # \"precursor_mz\": \"PrecursorMZ\"\n", + "}\n", + "df[\"summary\"] = df.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n", + ")\n", + "df.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder=None), axis=1\n", + ")" ] }, { @@ -902,7 +936,7 @@ " new_id = len(metabolite_id_map)\n", " metabolite.id = new_id\n", " metabolite_id_map[new_id] = metabolite\n", - " \n", + "\n", "print(f\"Found {len(metabolite_id_map)} unique molecular structures.\")" ] }, @@ -931,7 +965,9 @@ "example = df.loc[EXAMPLE_ID]\n", "m = example[\"Metabolite\"]\n", "\n", - "fig, axs = plt.subplots(1, 2, figsize=(8, 4), gridspec_kw={'width_ratios': [1, 1]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(8, 4), gridspec_kw={\"width_ratios\": [1, 1]}, sharey=False\n", + ")\n", "set_light_theme()\n", "\n", "img = m.draw(ax=axs[0])\n", @@ -994,7 +1030,9 @@ "example = df.loc[EXAMPLE_ID]\n", "m = example[\"Metabolite\"]\n", "\n", - "fig, axs = plt.subplots(1, 4, figsize=(8, 4), gridspec_kw={'width_ratios': [1, 1,1,1]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 4, figsize=(8, 4), gridspec_kw={\"width_ratios\": [1, 1, 1, 1]}, sharey=False\n", + ")\n", "set_light_theme()\n", "\n", "img = m.draw(ax=axs[0])\n", @@ -1009,7 +1047,9 @@ "example = df.loc[EXAMPLE_ID]\n", "m = example[\"Metabolite\"]\n", "\n", - "fig, axs = plt.subplots(1, 4, figsize=(8, 4), gridspec_kw={'width_ratios': [1, 1,1,1]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 4, figsize=(8, 4), gridspec_kw={\"width_ratios\": [1, 1, 1, 1]}, sharey=False\n", + ")\n", "set_light_theme()\n", "\n", "img = m.draw(ax=axs[0])\n", @@ -1023,7 +1063,9 @@ "example = df.loc[EXAMPLE_ID]\n", "m = example[\"Metabolite\"]\n", "\n", - "fig, axs = plt.subplots(1, 4, figsize=(8, 4), gridspec_kw={'width_ratios': [1, 1,1,1]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 4, figsize=(8, 4), gridspec_kw={\"width_ratios\": [1, 1, 1, 1]}, sharey=False\n", + ")\n", "set_light_theme()\n", "\n", "img = m.draw(ax=axs[0])\n", @@ -1037,7 +1079,9 @@ "example = df.loc[EXAMPLE_ID]\n", "m = example[\"Metabolite\"]\n", "\n", - "fig, axs = plt.subplots(1, 4, figsize=(8, 4), gridspec_kw={'width_ratios': [1, 1,1,1]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 4, figsize=(8, 4), gridspec_kw={\"width_ratios\": [1, 1, 1, 1]}, sharey=False\n", + ")\n", "set_light_theme()\n", "\n", "img = m.draw(ax=axs[0])\n", @@ -1086,13 +1130,32 @@ "\n", "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), sharey=False)\n", "\n", - "edges_bond_types = [item for items in list(df[\"Metabolite\"].apply(lambda x: getattr(x, \"edge_bond_names\"))) for item in items]\n", - "bond_types = 
{bond: edges_bond_types.count(bond) for bond in np.unique(edges_bond_types)}\n", + "edges_bond_types = [\n", + " item\n", + " for items in list(df[\"Metabolite\"].apply(lambda x: getattr(x, \"edge_bond_names\")))\n", + " for item in items\n", + "]\n", + "bond_types = {\n", + " bond: edges_bond_types.count(bond) for bond in np.unique(edges_bond_types)\n", + "}\n", "print(bond_types)\n", "print(list(bond_types.values()))\n", "\n", - "sns.barplot(ax=axs[0], x=list(bond_types.keys()), y=list(bond_types.values()), palette=color_palette, edgecolor=\"black\", linewidth=1.5)\n", - "_,labels,autotexts = axs[1].pie(list(bond_types.values()), labels=list(bond_types.keys()), colors=color_palette, wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5}, autopct='%1.0f%%')\n", + "sns.barplot(\n", + " ax=axs[0],\n", + " x=list(bond_types.keys()),\n", + " y=list(bond_types.values()),\n", + " palette=color_palette,\n", + " edgecolor=\"black\",\n", + " linewidth=1.5,\n", + ")\n", + "_, labels, autotexts = axs[1].pie(\n", + " list(bond_types.values()),\n", + " labels=list(bond_types.keys()),\n", + " colors=color_palette,\n", + " wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5},\n", + " autopct=\"%1.0f%%\",\n", + ")\n", "\n", "for i in range(len(labels)):\n", " if labels[i].get_text() == \"TRIPLE\":\n", @@ -1134,20 +1197,43 @@ ], "source": [ "from modules.visualization.define_colors import define_figure_style\n", + "\n", "color_palette = define_figure_style(\"magma_white\")\n", "\n", "\n", - "elems = [e for mol in list(df[\"Metabolite\"].apply(lambda x: getattr(x, \"node_elements\")).values) for e in mol]\n", + "elems = [\n", + " e\n", + " for mol in list(\n", + " df[\"Metabolite\"].apply(lambda x: getattr(x, \"node_elements\")).values\n", + " )\n", + " for e in mol\n", + "]\n", "elem_types = {e: elems.count(e) for e in np.unique(elems)}\n", "\n", "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), sharey=False)\n", - "sns.barplot(ax=axs[0], x=list(elem_types.keys()), y=list(elem_types.values()), palette=color_palette, edgecolor=\"black\", linewidth=1.5)\n", - "_,labels,autotexts = axs[1].pie(list(elem_types.values()), colors=color_palette, wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5}, autopct='%1.0f%%')\n", + "sns.barplot(\n", + " ax=axs[0],\n", + " x=list(elem_types.keys()),\n", + " y=list(elem_types.values()),\n", + " palette=color_palette,\n", + " edgecolor=\"black\",\n", + " linewidth=1.5,\n", + ")\n", + "_, labels, autotexts = axs[1].pie(\n", + " list(elem_types.values()),\n", + " colors=color_palette,\n", + " wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5},\n", + " autopct=\"%1.0f%%\",\n", + ")\n", "\n", "axs[1].legend(list(elem_types.keys()))\n", "plt.show()\n", "print(elem_types)\n", - "cno = (elem_types[\"C\"]+elem_types[\"O\"]+elem_types[\"N\"])*100 / sum(elem_types.values())\n", + "cno = (\n", + " (elem_types[\"C\"] + elem_types[\"O\"] + elem_types[\"N\"])\n", + " * 100\n", + " / sum(elem_types.values())\n", + ")\n", "print(f\"With {cno:.01f}% CNO\")" ] }, @@ -1175,12 +1261,33 @@ } ], "source": [ - "elems = [e for mol in list(df[\"Metabolite\"].apply(lambda x: [a.GetTotalNumHs() for a in getattr(x, \"atoms_in_order\")]).values) for e in mol]\n", + "elems = [\n", + " e\n", + " for mol in list(\n", + " df[\"Metabolite\"]\n", + " .apply(lambda x: [a.GetTotalNumHs() for a in getattr(x, \"atoms_in_order\")])\n", + " .values\n", + " )\n", + " for e in mol\n", + "]\n", "elem_types = {e: elems.count(e) for e in np.unique(elems)}\n", "\n", "fig, axs = plt.subplots(1, 2, 
figsize=(12.8, 6.4), sharey=False)\n", - "sns.barplot(ax=axs[0], x=list(elem_types.keys()), y=list(elem_types.values()), palette=color_palette, edgecolor=\"black\", linewidth=1.5)\n", - "_,labels,autotexts = axs[1].pie(list(elem_types.values()), colors=color_palette, labels=list(elem_types.keys()), wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5}, autopct='%1.0f%%')\n", + "sns.barplot(\n", + " ax=axs[0],\n", + " x=list(elem_types.keys()),\n", + " y=list(elem_types.values()),\n", + " palette=color_palette,\n", + " edgecolor=\"black\",\n", + " linewidth=1.5,\n", + ")\n", + "_, labels, autotexts = axs[1].pie(\n", + " list(elem_types.values()),\n", + " colors=color_palette,\n", + " labels=list(elem_types.keys()),\n", + " wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5},\n", + " autopct=\"%1.0f%%\",\n", + ")\n", "\n", "axs[1].legend(list(elem_types.keys()))\n", "plt.title(\"Number of bonded hydrogens\")\n", @@ -1207,11 +1314,22 @@ "source": [ "fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4))\n", "\n", - "d = {\"collision_energy\": df[\"Metabolite\"].apply(lambda x: x.metadata[\"collision_energy\"])}\n", - "\n", - "sns.histplot(ax=ax, data=d, x='collision_energy', color=\"blue\", fill=True, binwidth=2, edgecolor=\"black\", binrange=[0, 200])#, order=list(range(0,200)))\n", + "d = {\n", + " \"collision_energy\": df[\"Metabolite\"].apply(lambda x: x.metadata[\"collision_energy\"])\n", + "}\n", + "\n", + "sns.histplot(\n", + " ax=ax,\n", + " data=d,\n", + " x=\"collision_energy\",\n", + " color=\"blue\",\n", + " fill=True,\n", + " binwidth=2,\n", + " edgecolor=\"black\",\n", + " binrange=[0, 200],\n", + ") # , order=list(range(0,200)))\n", "plt.rcParams[\"patch.force_edgecolor\"] = True\n", - "plt.rcParams['axes.edgecolor'] = 'black'\n", + "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n", "plt.show()" ] }, @@ -1234,11 +1352,22 @@ "source": [ "fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4))\n", "\n", - "d = {\"molecular_weight\": df[\"Metabolite\"].apply(lambda x: x.metadata[\"molecular_weight\"])}\n", - "\n", - "sns.histplot(ax=ax, data=d, x='molecular_weight', color=color_palette[3], fill=True, binwidth=2, edgecolor=\"black\", binrange=[0, 1000])#, order=list(range(0,200)))\n", + "d = {\n", + " \"molecular_weight\": df[\"Metabolite\"].apply(lambda x: x.metadata[\"molecular_weight\"])\n", + "}\n", + "\n", + "sns.histplot(\n", + " ax=ax,\n", + " data=d,\n", + " x=\"molecular_weight\",\n", + " color=color_palette[3],\n", + " fill=True,\n", + " binwidth=2,\n", + " edgecolor=\"black\",\n", + " binrange=[0, 1000],\n", + ") # , order=list(range(0,200)))\n", "plt.rcParams[\"patch.force_edgecolor\"] = True\n", - "plt.rcParams['axes.edgecolor'] = 'black'\n", + "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n", "plt.show()" ] }, @@ -1258,8 +1387,14 @@ "source": [ "%%capture\n", "from modules.MOL.constants import PPM\n", + "\n", "df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "df.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 *PPM), axis=1)" + "df.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -1268,10 +1403,20 @@ "metadata": {}, "outputs": [], "source": [ - "d100 = pd.DataFrame({\"num_peak_matches\": df[\"Metabolite\"].apply(lambda x: x.match_stats[\"num_peak_matches\"]),\n", - " \"num_non_precursor_matches\": df[\"Metabolite\"].apply(lambda x: 
x.match_stats[\"num_non_precursor_matches\"]),\n", - " \"num_peak_match_conflicts\": df[\"Metabolite\"].apply(lambda x: x.match_stats[\"num_peak_match_conflicts\"]),\n", - " \"group\": \"100 PPM\"})\n" + "d100 = pd.DataFrame(\n", + " {\n", + " \"num_peak_matches\": df[\"Metabolite\"].apply(\n", + " lambda x: x.match_stats[\"num_peak_matches\"]\n", + " ),\n", + " \"num_non_precursor_matches\": df[\"Metabolite\"].apply(\n", + " lambda x: x.match_stats[\"num_non_precursor_matches\"]\n", + " ),\n", + " \"num_peak_match_conflicts\": df[\"Metabolite\"].apply(\n", + " lambda x: x.match_stats[\"num_peak_match_conflicts\"]\n", + " ),\n", + " \"group\": \"100 PPM\",\n", + " }\n", + ")" ] }, { @@ -1300,36 +1445,42 @@ "source": [ "# TODO Implement conflict solver\n", "\n", - "coverage_tracker = {\"counts\": [], \"all\": [], \"coverage\": [], \"fragment_only_coverage\": []}\n", + "coverage_tracker = {\n", + " \"counts\": [],\n", + " \"all\": [],\n", + " \"coverage\": [],\n", + " \"fragment_only_coverage\": [],\n", + "}\n", "\n", "drop_index = []\n", - "for i,d in df.iterrows():\n", + "for i, d in df.iterrows():\n", " M = d[\"Metabolite\"]\n", - " \n", - " \n", + "\n", " coverage_tracker[\"counts\"] += [M.match_stats[\"counts\"]]\n", " coverage_tracker[\"all\"] += [M.match_stats[\"ms_all_counts\"]]\n", " coverage_tracker[\"fragment_only_coverage\"] += [M.match_stats[\"coverage_wo_prec\"]]\n", " coverage_tracker[\"coverage\"] += [M.match_stats[\"coverage\"]]\n", - " \n", - " #if M.edge_break_prob_wo_precursor.sum() <= 0.01:\n", + "\n", + " # if M.edge_break_prob_wo_precursor.sum() <= 0.01:\n", " # drop_index.append(i)\n", - " #if M.edge_break_prob.sum() < 0.05: # TODO\n", + " # if M.edge_break_prob.sum() < 0.05: # TODO\n", " # drop_index.append(i)\n", - " \n", - " if M.match_stats[\"coverage\"] < 0.25: # Filter if total coverage is too low\n", + "\n", + " if M.match_stats[\"coverage\"] < 0.25: # Filter if total coverage is too low\n", " drop_index.append(i)\n", - " #if M.match_stats[\"coverage_wo_prec\"] < 0.1: # Filter if fragment coverage is too low (intensity wise)\n", + " # if M.match_stats[\"coverage_wo_prec\"] < 0.1: # Filter if fragment coverage is too low (intensity wise)\n", " # drop_index.append(i)\n", - " \n", + "\n", "# filter low res instruments TODO update to low quality spectra\n", "is_iontrap = df[\"Metabolite\"].apply(lambda x: x.metadata[\"instrument\"] == \"IT/ion trap\")\n", "drop_index += list(df[is_iontrap].index)\n", "\n", "fig, axs = plt.subplots(1, 3, figsize=(12.8, 4.2), sharey=True)\n", "\n", - "plt.ylim([-.02,1.02])\n", - "sns.boxplot(ax=axs[0], data=coverage_tracker, y=\"fragment_only_coverage\", color=color_palette[1])\n", + "plt.ylim([-0.02, 1.02])\n", + "sns.boxplot(\n", + " ax=axs[0], data=coverage_tracker, y=\"fragment_only_coverage\", color=color_palette[1]\n", + ")\n", "sns.boxplot(ax=axs[1], data=coverage_tracker, y=\"coverage\", color=color_palette[2])\n", "sns.swarmplot(ax=axs[2], data=coverage_tracker, y=\"coverage\", color=color_palette[2])\n", "axs[0].set_title(\"Coverage of peak intensity (fragments only)\")\n", @@ -1337,7 +1488,7 @@ "axs[2].set_title(\"Coverage of peak intensity\")\n", "plt.show()\n", "\n", - "print(f\"Filtering would drop {len(drop_index)} out of {df.shape[0]}\")\n" + "print(f\"Filtering would drop {len(drop_index)} out of {df.shape[0]}\")" ] }, { @@ -1515,7 +1666,9 @@ "import modules.IO.cfmReader as cfmReader\n", "# time CFM-ID 4: -> 12m16,571s\n", "\n", - "cf = 
cfmReader.read(f\"{home}/data/metabolites/cfm-id/casmi16_positive_predictions.txt\", as_df=True)\n", + "cf = cfmReader.read(\n", + " f\"{home}/data/metabolites/cfm-id/casmi16_positive_predictions.txt\", as_df=True\n", + ")\n", "cf.head(2)" ] }, @@ -1554,7 +1707,9 @@ } ], "source": [ - "library_directory = f\"{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Training_negative_mgf\"\n", + "library_directory = (\n", + " f\"{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Training_negative_mgf\"\n", + ")\n", "!ls $library_directory" ] }, @@ -1568,17 +1723,19 @@ "\n", "for file in os.listdir(library_directory):\n", " if file.endswith(\".mgf\"):\n", - " data = mgfReader.read(os.path.join(library_directory, file), sep=\"\\t\", debug=False)[0]\n", - " data['FILE'] = file\n", - " data['Precursor_type'] = \"[M-H]-\"\n", + " data = mgfReader.read(\n", + " os.path.join(library_directory, file), sep=\"\\t\", debug=False\n", + " )[0]\n", + " data[\"FILE\"] = file\n", + " data[\"Precursor_type\"] = \"[M-H]-\"\n", " df += [data]\n", "\n", "library_directory = library_directory.replace(\"negative\", \"positive\")\n", "for file in os.listdir(library_directory):\n", " if file.endswith(\".mgf\"):\n", " data = mgfReader.read(os.path.join(library_directory, file), sep=\"\\t\")[0]\n", - " data['FILE'] = file\n", - " data['Precursor_type'] = \"[M+H]+\"\n", + " data[\"FILE\"] = file\n", + " data[\"Precursor_type\"] = \"[M+H]+\"\n", " df += [data]\n", "\n", "df = pd.DataFrame(df)\n", @@ -1703,10 +1860,12 @@ } ], "source": [ - "solution = pd.read_csv(os.path.join(library_directory, \"..\", \"CASMI2016_Cat2and3_Training.csv\"))\n", + "solution = pd.read_csv(\n", + " os.path.join(library_directory, \"..\", \"CASMI2016_Cat2and3_Training.csv\")\n", + ")\n", "# solution = solution[solution[\"ION_MODE\"] == \" POSITIVE\"]\n", "# solution.reset_index(inplace=True, drop=True)\n", - "#df = pd.concat([df, solution], axis=1)\n", + "# df = pd.concat([df, solution], axis=1)\n", "\n", "df = pd.merge(df, solution, left_on=\"ChallengeName\", right_on=\"challengename\")\n", "\n", @@ -1752,6 +1911,7 @@ ], "source": [ "import seaborn as sns\n", + "\n", "print(df.groupby(\"Precursor_type\")[\"coverage100PPM\"].median())\n", "sns.kdeplot(data=df, x=\"coverage100PPM\", hue=\"Precursor_type\")" ] @@ -1782,17 +1942,18 @@ } ], "source": [ - "\n", - "library_candidates_directory = f\"{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Training_Candidates\"\n", + "library_candidates_directory = (\n", + " f\"{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Training_Candidates\"\n", + ")\n", "candidates = []\n", "for file in os.listdir(library_candidates_directory):\n", " if file.endswith(\".csv\"):\n", " data = pd.read_csv(os.path.join(library_candidates_directory, file), sep=\",\")\n", " SMILES = list(data[\"SMILES\"])\n", - " d = {'ChallengeName': file.split(\".\")[0], 'cFILE': file, 'Candidates': SMILES}\n", + " d = {\"ChallengeName\": file.split(\".\")[0], \"cFILE\": file, \"Candidates\": SMILES}\n", " candidates += [d]\n", "\n", - "candidates = pd.DataFrame(candidates) #.loc[81:].reset_index(drop=True)\n", + "candidates = pd.DataFrame(candidates) # .loc[81:].reset_index(drop=True)\n", "print(candidates.head())" ] }, @@ -1813,10 +1974,8 @@ } ], "source": [ - "\n", - "\n", "df = pd.merge(df, candidates, on=\"ChallengeName\")\n", - "df.apply(lambda x: x[\"cFILE\"].split('.')[0] == x[\"ChallengeName\"], axis=1).all()" + "df.apply(lambda x: x[\"cFILE\"].split(\".\")[0] == x[\"ChallengeName\"], axis=1).all()" ] }, { @@ -1868,14 
+2027,14 @@ "save_df = False\n", "name = \"casmi16t.csv\"\n", "\n", - "library_directory = '/'.join(library_directory.split('/')[:-1])\n", + "library_directory = \"/\".join(library_directory.split(\"/\")[:-1])\n", "print(library_directory)\n", "if save_df:\n", - " file = os.path.join(library_directory, name)\n", - " print(\"saving to \", file)\n", - " df.to_csv(file)\n", - " \n", - " #df_nist[key_columns].to_csv(library_directory + name + \"all\" + \"_\" + date + \".csv\")" + " file = os.path.join(library_directory, name)\n", + " print(\"saving to \", file)\n", + " df.to_csv(file)\n", + "\n", + " # df_nist[key_columns].to_csv(library_directory + name + \"all\" + \"_\" + date + \".csv\")" ] }, { @@ -1913,11 +2072,15 @@ "\n", "if save_df:\n", " file = os.path.join(cfm_directory, name)\n", - " df_cfm[df_cfm[\"Precursor_type\"] == \"[M-H]-\"][[\"ChallengeName\", \"SMILES\"]].to_csv(file, index=False, header=False, sep=\" \")\n", - " \n", + " df_cfm[df_cfm[\"Precursor_type\"] == \"[M-H]-\"][[\"ChallengeName\", \"SMILES\"]].to_csv(\n", + " file, index=False, header=False, sep=\" \"\n", + " )\n", + "\n", " name = name.replace(\"negative\", \"positive\")\n", " file = os.path.join(cfm_directory, name)\n", - " df_cfm[df_cfm[\"Precursor_type\"] == \"[M+H]+\"][[\"ChallengeName\", \"SMILES\"]].to_csv(file, index=False, header=False, sep=\" \")" + " df_cfm[df_cfm[\"Precursor_type\"] == \"[M+H]+\"][[\"ChallengeName\", \"SMILES\"]].to_csv(\n", + " file, index=False, header=False, sep=\" \"\n", + " )" ] }, { @@ -1933,7 +2096,15 @@ "df[\"MOL\"] = df[\"SMILES\"].apply(Chem.MolFromSmiles)\n", "df[\"formula\"] = df[\"MOL\"].apply(rdMolDescriptors.CalcMolFormula)\n", "df[\"dataset\"] = \"CASMI16\"\n", - "df = df.rename(columns={\"FILE\": \"spec\", \"ChallengeName\": \"name\", \"Precursor_type\": \"ionization\", \"SMILES\": \"smiles\", \"INCHIKEY\": \"inchikey\"})" + "df = df.rename(\n", + " columns={\n", + " \"FILE\": \"spec\",\n", + " \"ChallengeName\": \"name\",\n", + " \"Precursor_type\": \"ionization\",\n", + " \"SMILES\": \"smiles\",\n", + " \"INCHIKEY\": \"inchikey\",\n", + " }\n", + ")" ] }, { @@ -1945,9 +2116,13 @@ "save_df = False\n", "if save_df:\n", " output_file = f\"{home}/data/metabolites/ms-pred/casmi16t_labels.tsv\"\n", - " df[[\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\",\t\"smiles\", \"inchikey\"]].to_csv(output_file, index=False, sep=\"\\t\")\n", + " df[\n", + " [\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\", \"smiles\", \"inchikey\"]\n", + " ].to_csv(output_file, index=False, sep=\"\\t\")\n", " output_file = f\"{home}/data/metabolites/ms-pred/casmi16t_positive_labels.tsv\"\n", - " df[df[\"ionization\"] == \"[M+H]+\"][[\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\",\t\"smiles\", \"inchikey\"]].to_csv(output_file, index=False, sep=\"\\t\")" + " df[df[\"ionization\"] == \"[M+H]+\"][\n", + " [\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\", \"smiles\", \"inchikey\"]\n", + " ].to_csv(output_file, index=False, sep=\"\\t\")" ] } ], diff --git a/lib_loader/casmi22_loader.ipynb b/lib_loader/casmi22_loader.ipynb index fffb4e4..648de06 100644 --- a/lib_loader/casmi22_loader.ipynb +++ b/lib_loader/casmi22_loader.ipynb @@ -15,12 +15,14 @@ ], "source": [ "import sys\n", - "print(f'Working with Python {sys.version}')\n", + "\n", + "print(f\"Working with Python {sys.version}\")\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import importlib\n", - "#import swifter\n", + "\n", + "# import swifter\n", "import matplotlib.pyplot as plt\n", 
"import seaborn as sns\n", "import collections\n", @@ -39,6 +41,7 @@ "# Load Modules\n", "sys.path.append(\"..\")\n", "from os.path import expanduser\n", + "\n", "home = expanduser(\"~\")\n", "import modules.IO.mspReader as mspReader\n", "import modules.IO.mgfReader as mgfReader\n", @@ -46,7 +49,7 @@ "import modules.IO.molReader as molReader\n", "\n", "\n", - "RDLogger.DisableLog('rdApp.*')\n" + "RDLogger.DisableLog(\"rdApp.*\")" ] }, { @@ -76,16 +79,19 @@ "import modules.visualization.spectrum_visualizer as sv\n", "from modules.MS.spectral_scores import spectral_cosine\n", "\n", + "\n", "def extract_compound_spectra(df, rt, precursor_mz, rt_tolerance=0.1, mz_tolerance=0.1):\n", " df[\"RT_dif\"] = abs(df[\"RT_min\"] - rt)\n", " r = df[\"RT_dif\"] < rt_tolerance\n", " p = abs(df[\"precursor_mz\"] - precursor_mz) < mz_tolerance\n", "\n", - " return df[np.logical_and(r,p)]\n", + " return df[np.logical_and(r, p)]\n", + "\n", "\n", "def score_against_merged(df, ref_spec):\n", " for i, d in df.iterrows():\n", - " df.at[i, 'ref_score'] = spectral_cosine(d[\"peaks\"], ref_spec, transform=np.sqrt)\n", + " df.at[i, \"ref_score\"] = spectral_cosine(d[\"peaks\"], ref_spec, transform=np.sqrt)\n", + "\n", "\n", "def merge_df(DF):\n", " if DF.shape[0] < 0:\n", @@ -98,58 +104,85 @@ " else:\n", " m = merge_spectrum(m, DF.iloc[i][\"peaks\"].copy(), merge_tolerance=0.01)\n", "\n", - " M = pd.DataFrame({\"RT\": DF[\"RT\"].mean(), \"RT_min\": DF[\"RT_min\"].mean(), \"precursor_mz\": DF[\"precursor_mz\"].mean(), \"Instrument_type\": \"HCD\", \"NCE\": DF[\"NCE\"].mean(), \"peaks\": [m]})\n", + " M = pd.DataFrame(\n", + " {\n", + " \"RT\": DF[\"RT\"].mean(),\n", + " \"RT_min\": DF[\"RT_min\"].mean(),\n", + " \"precursor_mz\": DF[\"precursor_mz\"].mean(),\n", + " \"Instrument_type\": \"HCD\",\n", + " \"NCE\": DF[\"NCE\"].mean(),\n", + " \"peaks\": [m],\n", + " }\n", + " )\n", " score_against_merged(DF, m)\n", " return M\n", - " \n", - "def extract_challenge(file: str, rt: float, precursor_mz: float, rt_tolerance:float=0.1, mz_tolerance:float=0.1, in_ppm=False, verbose: bool=False):\n", + "\n", + "\n", + "def extract_challenge(\n", + " file: str,\n", + " rt: float,\n", + " precursor_mz: float,\n", + " rt_tolerance: float = 0.1,\n", + " mz_tolerance: float = 0.1,\n", + " in_ppm=False,\n", + " verbose: bool = False,\n", + "):\n", " # Load data\n", " exp = MSExperiment()\n", " MzMLFile().load(file, exp)\n", - " \n", - " #Create dataframe with metadata\n", + "\n", + " # Create dataframe with metadata\n", " df = exp.get_df()\n", " df[\"RT_min\"] = df[\"RT\"] / 60.0\n", - " df[\"precursor_mz\"] = [spec.getAcquisitionInfo()[0].getMetaValue(\"[Thermo Trailer Extra]Monoisotopic M/Z:\") for spec in exp]\n", + " df[\"precursor_mz\"] = [\n", + " spec.getAcquisitionInfo()[0].getMetaValue(\n", + " \"[Thermo Trailer Extra]Monoisotopic M/Z:\"\n", + " )\n", + " for spec in exp\n", + " ]\n", " df[\"ms_level\"] = [spec.getMSLevel() for spec in exp]\n", " df[\"filter_string\"] = [spec.getMetaValue(\"filter string\") for spec in exp]\n", " df = df[df[\"ms_level\"] == 2]\n", " df[\"hcd\"] = df[\"filter_string\"].apply(lambda x: x.split(\"@\")[1].split(\" \")[0])\n", " df[\"Instrument_type\"] = \"HCD\"\n", " df[\"NCE\"] = df[\"hcd\"].apply(lambda x: x[3:]).astype(float)\n", - " df[\"peaks\"] = df.apply(lambda x: {\"mz\": list(x[\"mzarray\"]), \"intensity\": list(x[\"intarray\"])}, axis=1)\n", - " \n", - " \n", - " \n", + " df[\"peaks\"] = df.apply(\n", + " lambda x: {\"mz\": list(x[\"mzarray\"]), \"intensity\": list(x[\"intarray\"])}, 
axis=1\n",
+    "    )\n",
+    "\n",
     "    ## Extract compound spectra\n",
     "    if in_ppm:\n",
     "        mz_tolerance = mz_tolerance * precursor_mz\n",
-    "    \n",
-    "    df_extract = extract_compound_spectra(df, rt, precursor_mz, rt_tolerance=rt_tolerance, mz_tolerance=mz_tolerance)\n",
+    "\n",
+    "    df_extract = extract_compound_spectra(\n",
+    "        df, rt, precursor_mz, rt_tolerance=rt_tolerance, mz_tolerance=mz_tolerance\n",
+    "    )\n",
     "    df_extract[\"peaks\"].apply(normalize_spectrum)\n",
     "\n",
-    "    if verbose: \n",
-    "        print(f\"Extracted {df_extract.shape[0]} spectra with match rt and mz\" )\n",
+    "    if verbose:\n",
+    "        print(f\"Extracted {df_extract.shape[0]} spectra with matching rt and mz\")\n",
     "\n",
     "    df_low = df_extract[df_extract[\"NCE\"] == 35.0].sort_values(\"RT_dif\", ascending=True)\n",
     "    df_med = df_extract[df_extract[\"NCE\"] == 45.0].sort_values(\"RT_dif\", ascending=True)\n",
-    "    df_high = df_extract[df_extract[\"NCE\"] == 65.0].sort_values(\"RT_dif\", ascending=True)\n",
-    "    \n",
-    "    #print(df_low.shape[0],df_med.shape[0],df_high.shape[0] )\n",
-    "    \n",
+    "    df_high = df_extract[df_extract[\"NCE\"] == 65.0].sort_values(\n",
+    "        \"RT_dif\", ascending=True\n",
+    "    )\n",
+    "\n",
+    "    # print(df_low.shape[0],df_med.shape[0],df_high.shape[0] )\n",
+    "\n",
     "    challenges = pd.DataFrame()\n",
     "    ref_scores = []\n",
     "    if df_low.shape[0] > 0:\n",
     "        low = merge_df(df_low)\n",
     "        challenges = pd.concat([challenges, low])\n",
     "        ref_scores += list(df_low[\"ref_score\"].values)\n",
-    "    \n",
+    "\n",
     "    if df_med.shape[0] > 0:\n",
-    "        med = merge_df(df_med)\n",
+    "        med = merge_df(df_med)\n",
     "        challenges = pd.concat([challenges, med])\n",
     "        ref_scores += list(df_med[\"ref_score\"].values)\n",
     "\n",
-    "    if df_high.shape[0] > 0: \n",
+    "    if df_high.shape[0] > 0:\n",
     "        high = merge_df(df_high)\n",
     "        challenges = pd.concat([challenges, high])\n",
     "        ref_scores += list(df_high[\"ref_score\"].values)\n",
@@ -160,23 +193,25 @@
     "    else:\n",
     "        min_ref_score = min(ref_scores)\n",
     "        if min_ref_score < 0.9:\n",
-    "            print(f\"Warning: Low cosine score of {min_ref_score:.2f} detected. (All ref_scores {ref_scores})\")\n",
-    "            #raise Warning(\"Low cosine score detected between merged spectrum and at least one experimental spectrum. Tolerance values might have picked up a false RT/MZ match\")\n",
+    "            print(\n",
+    "                f\"Warning: Low cosine score of {min_ref_score:.2f} detected. (All ref_scores {ref_scores})\"\n",
+    "            )\n",
+    "            # raise Warning(\"Low cosine score detected between merged spectrum and at least one experimental spectrum. Tolerance values might have picked up a false RT/MZ match\")\n",
     "\n",
-    "    if verbose: \n",
-    "        print(f\"\\nMerged {df_low.shape[0]} with NCE 35\" )\n",
+    "    if verbose:\n",
+    "        print(f\"\\nMerged {df_low.shape[0]} with NCE 35\")\n",
     "        print(df_low[\"ref_score\"])\n",
-    "        print(f\"\\nMerged {df_med.shape[0]} with NCE 45\" )\n",
+    "        print(f\"\\nMerged {df_med.shape[0]} with NCE 45\")\n",
     "        print(df_med[\"ref_score\"])\n",
-    "        print(f\"\\nMerged {df_high.shape[0]} with NCE 65\" )\n",
+    "        print(f\"\\nMerged {df_high.shape[0]} with NCE 65\")\n",
     "        print(df_high[\"ref_score\"])\n",
     "        sv.plot_spectrum(df_low.iloc[0], challenges.iloc[0])\n",
     "        sv.plot_spectrum(df_med.iloc[0], challenges.iloc[1])\n",
     "        sv.plot_spectrum(df_high.iloc[0], challenges.iloc[2])\n",
     "        plt.show()\n",
-    "    \n",
+    "\n",
     "    print(f\"Minimum cosine score to merged spectra: {min_ref_score:.2f} (pass)\")\n",
-    "    \n",
+    "\n",
     "    return challenges, ref_scores"
    ]
   },
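Each merged challenge spectrum above is sanity-checked with a square-root-transformed spectral cosine against its constituent scans. A minimal self-contained sketch of such a score (assumptions: a fixed-width shared m/z grid and all peaks below max_mz; the real scorer is spectral_cosine from modules.MS.spectral_scores and differs in its binning details):

import numpy as np

def cosine_sqrt(peaks_a, peaks_b, bin_width=0.01, max_mz=2000.0):
    # Bin both peak lists onto one grid (np.add.at sums colliding peaks, as in
    # the binning cells further below), sqrt-transform the intensities, then
    # take the normalized dot product.
    n_bins = int(max_mz / bin_width) + 1
    vec_a, vec_b = np.zeros(n_bins), np.zeros(n_bins)
    bins_a = (np.asarray(peaks_a["mz"]) / bin_width).astype(int)
    bins_b = (np.asarray(peaks_b["mz"]) / bin_width).astype(int)
    np.add.at(vec_a, bins_a, peaks_a["intensity"])
    np.add.at(vec_b, bins_b, peaks_b["intensity"])
    a, b = np.sqrt(vec_a), np.sqrt(vec_b)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(a @ b / denom) if denom else 0.0

spec = {"mz": [100.0, 250.5], "intensity": [1.0, 0.5]}
print(cosine_sqrt(spec, spec))  # 1.0 for identical spectra
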
@@ -291,7 +326,6 @@
     "path = os.path.join(l, f)\n",
     "\n",
     "\n",
-    "\n",
     "extract_challenge(path, 5.55, 719.2546, verbose=True)"
    ]
   },
@@ -428,7 +462,7 @@
    "source": [
     "l = f\"{home}/data/metabolites/CASMI_2022/\"\n",
     "key_file = \"MetSoc2022_CASMI_Workshop_Challenges_KEY_ALL_FINAL.csv\"\n",
-    "challenge_key = pd.read_csv(os.path.join(l, key_file), sep='\\t')\n",
+    "challenge_key = pd.read_csv(os.path.join(l, key_file), sep=\"\\t\")\n",
     "challenge_key.head(3)"
    ]
   },
@@ -582,8 +616,11 @@
    }
   ],
   "source": [
-    "\n",
-    "challenge_key = challenge_key[np.logical_or(challenge_key[\"Adduct\"] == \"[M+H]+\", challenge_key[\"Adduct\"] == \"[M-H]-\")]\n",
+    "challenge_key = challenge_key[\n",
+    "    np.logical_or(\n",
+    "        challenge_key[\"Adduct\"] == \"[M+H]+\", challenge_key[\"Adduct\"] == \"[M-H]-\"\n",
+    "    )\n",
+    "]\n",
     "\n",
     "print(challenge_key.shape)\n",
     "\n",
@@ -759,10 +796,18 @@
     "    adduct = d[\"Adduct\"]\n",
     "    name = d[\"Compound Number\"]\n",
     "    smiles = d[\"SMILES\"].strip()\n",
-    "    \n",
+    "\n",
     "    path = os.path.join(l, \"mzml/\", file)\n",
-    "    \n",
-    "    c, ref_scores = extract_challenge(path, rt, precursor_mz, rt_tolerance=2.5 / 60.0, mz_tolerance=10 * PPM, in_ppm=True, verbose=False) # 5 seconds torance, 10 ppm precursor mass\n",
+    "\n",
+    "    c, ref_scores = extract_challenge(\n",
+    "        path,\n",
+    "        rt,\n",
+    "        precursor_mz,\n",
+    "        rt_tolerance=2.5 / 60.0,\n",
+    "        mz_tolerance=10 * PPM,\n",
+    "        in_ppm=True,\n",
+    "        verbose=False,\n",
+    "    )  # 5 second tolerance, 10 ppm precursor mass\n",
     "    if c.shape[0] == 0:\n",
     "        print(f\"No match found for Compound Number {name}\")\n",
     "        misses += 1\n",
@@ -771,10 +816,10 @@
     "    avg_min_ref_score += [min(ref_scores)]\n",
     "    avg_num_spectra += [len(ref_scores)]\n",
     "    c[\"Precursor_type\"] = adduct\n",
-    "    c[\"ChallengeName\"] = \"Challenge-\"+str(name)\n",
+    "    c[\"ChallengeName\"] = \"Challenge-\" + str(name)\n",
     "    c[\"ChallengeRT\"] = rt\n",
     "    c[\"SMILES\"] = smiles\n",
-    "    challenges = pd.concat([challenges, c], axis=0)\n"
+    "    challenges = pd.concat([challenges, c], axis=0)"
    ]
   },
   {
@@ -797,7 +842,9 @@
    }
   ],
   "source": [
-    "challenges[challenges[\"ChallengeName\"] == \"Challenge-277\"][\"SMILES\"]#.apply(lambda x: x.strip())"
+    "challenges[challenges[\"ChallengeName\"] == \"Challenge-277\"][\n",
+    "    \"SMILES\"\n",
+    "]  # .apply(lambda x: x.strip())"
    ]
   },
   {
@@ -997,12 +1044,11 @@
    "name = \"casmi22_challenges_combined_accurate.csv\"\n",
    "\n",
    "if save_df:\n",
-    "    file = os.path.join(l, name)\n",
-    "    print(\"saving to \", file)\n",
-    "    challenges.to_csv(file)\n",
-    "    \n",
-    "    #df_nist[key_columns].to_csv(library_directory + name + \"all\" + \"_\" + date + 
\".csv\")\n", - "\n" + " file = os.path.join(l, name)\n", + " print(\"saving to \", file)\n", + " challenges.to_csv(file)\n", + "\n", + " # df_nist[key_columns].to_csv(library_directory + name + \"all\" + \"_\" + date + \".csv\")" ] }, { @@ -1023,17 +1069,21 @@ "save_df = False\n", "cfm_directory = f\"{home}/data/metabolites/cfm-id/\"\n", "name = \"casmi22_negative_solutions_cfm.txt\"\n", - "unique_challenges = challenges.drop_duplicates(subset='ChallengeName', keep='first')\n", + "unique_challenges = challenges.drop_duplicates(subset=\"ChallengeName\", keep=\"first\")\n", "\n", "if save_df:\n", " file = os.path.join(cfm_directory, name)\n", " print(\"saving to \", file)\n", - " unique_challenges[unique_challenges[\"Precursor_type\"] == \"[M-H]-\"][[\"ChallengeName\", \"SMILES\"]].to_csv(file, index=False, header=False, sep=\" \")\n", - " \n", + " unique_challenges[unique_challenges[\"Precursor_type\"] == \"[M-H]-\"][\n", + " [\"ChallengeName\", \"SMILES\"]\n", + " ].to_csv(file, index=False, header=False, sep=\" \")\n", + "\n", " name = name.replace(\"negative\", \"positive\")\n", " file = os.path.join(cfm_directory, name)\n", " print(\"saving to \", file)\n", - " unique_challenges[unique_challenges[\"Precursor_type\"] == \"[M+H]+\"][[\"ChallengeName\", \"SMILES\"]].to_csv(file, index=False, header=False, sep=\" \")" + " unique_challenges[unique_challenges[\"Precursor_type\"] == \"[M+H]+\"][\n", + " [\"ChallengeName\", \"SMILES\"]\n", + " ].to_csv(file, index=False, header=False, sep=\" \")" ] }, { @@ -1117,7 +1167,7 @@ "\n", "print(challenge_key[challenge_key[\"Compound Number\"] == 78])\n", "\n", - "#print(extract_challenge(249.1496))" + "# print(extract_challenge(249.1496))" ] }, { @@ -1251,9 +1301,9 @@ } ], "source": [ - "vec, vec_ref = np.zeros(len(mz_map)), np.zeros(len(mz_map)) \n", - "np.add.at(vec, bins, peaks1[\"intensity\"]) #vec.put(bins, spec[\"intensity\"])\n", - "np.add.at(vec_ref, bins_ref, mm[\"intensity\"]) \n", + "vec, vec_ref = np.zeros(len(mz_map)), np.zeros(len(mz_map))\n", + "np.add.at(vec, bins, peaks1[\"intensity\"]) # vec.put(bins, spec[\"intensity\"])\n", + "np.add.at(vec_ref, bins_ref, mm[\"intensity\"])\n", "\n", "print(vec)\n", "print(vec_ref)" @@ -1306,7 +1356,7 @@ ], "source": [ "print(df_low.iloc[0][\"peaks\"][\"mz\"][-10:])\n", - "print(df_low.iloc[1][\"peaks\"][\"mz\"][-10:])\n" + "print(df_low.iloc[1][\"peaks\"][\"mz\"][-10:])" ] }, { @@ -1426,14 +1476,14 @@ } ], "source": [ - "#l = f\"{home}/data/metabolites/CASMI_2022/3_Data-20220516T091400Z-001/3_Data/mzML Data/1_Priority - Challenges 1-250/pos\"\n", + "# l = f\"{home}/data/metabolites/CASMI_2022/3_Data-20220516T091400Z-001/3_Data/mzML Data/1_Priority - Challenges 1-250/pos\"\n", "l = f\"{home}/data/metabolites/CASMI_2022/download/\"\n", "f = \"A_M3_posPFP_01.mzml\"\n", "\n", "run = pymzml.run.Reader(os.path.join(l, f))\n", "\n", "for spec in run:\n", - " mz, intensity = spec.peaks(\"centroided\")[:,0],spec.peaks(\"centroided\")[:,1] \n", + " mz, intensity = spec.peaks(\"centroided\")[:, 0], spec.peaks(\"centroided\")[:, 1]\n", " print(spec.scan_time)\n", " for i in spec.__dir__():\n", " if \"filter\" in i:\n", @@ -1487,6 +1537,7 @@ ], "source": [ "from pyteomics import mzml, auxiliary, mgf\n", + "\n", "path = os.path.join(l, f)\n", "reader = mzml.read(path)\n", "auxiliary.print_tree(next(reader))" @@ -1528,7 +1579,7 @@ "from matchms.importing import load_from_mzml\n", "from matchms.filtering import default_filters\n", "\n", - "reader = load_from_mzml(path)\n" + "reader = load_from_mzml(path)" ] }, { 
@@ -1552,7 +1603,7 @@ } ], "source": [ - "#print(next(reader).metadata)" + "# print(next(reader).metadata)" ] }, { @@ -1565,7 +1616,7 @@ "for spectrum in reader:\n", " # Apply default filter to standardize ion mode, correct charge and more.\n", " # Default filter is fully explained at https://matchms.readthedocs.io/en/latest/api/matchms.filtering.html .\n", - " #spectrum = default_filters(spectrum)\n", + " # spectrum = default_filters(spectrum)\n", " spectrums.append(spectrum)" ] }, @@ -1599,7 +1650,7 @@ "df[\"is_match\"] = np.logical_and(df[\"is_rt_match\"], df[\"is_mz_match\"])\n", "example = df[df[\"is_match\"]].iloc[0]\n", "spec = example[\"spectrum\"]\n", - "print()\n" + "print()" ] }, { @@ -1832,7 +1883,7 @@ "\n", "\n", "print(open(path).readlines()[197].strip())\n", - "print(path)\n" + "print(path)" ] }, { @@ -1921,10 +1972,11 @@ ], "source": [ "from pyopenms import *\n", + "\n", "exp = MSExperiment()\n", "MzMLFile().load(path, exp)\n", "\n", - "print( exp.getSpectrum(1).get_peaks()[0] )\n", + "print(exp.getSpectrum(1).get_peaks()[0])\n", "# [ 0. 2. 4. 6. 8. 10. 12. 14. 16. 18.]" ] }, @@ -1947,13 +1999,11 @@ "\n", "for spec in exp:\n", " if abs(spec.getRT() - 5.55) < 0.1:\n", - " #if spec.getMSLevel() == 2:\n", - " spec1 = spec\n", - " break\n", - "\n", - "print(spec1)\n", + " # if spec.getMSLevel() == 2:\n", + " spec1 = spec\n", + " break\n", "\n", - "\n" + "print(spec1)" ] }, { @@ -2015,7 +2065,6 @@ } ], "source": [ - "\n", "for g in getters:\n", " print(getattr(spec1, g))" ] @@ -2096,6 +2145,7 @@ "outputs": [], "source": [ "from pyopenms import *\n", + "\n", "exp = MSExperiment()\n", "MzMLFile().load(path, exp)\n", "\n", @@ -2617,8 +2667,9 @@ "\n", "for line in lines:\n", " line = line.strip()\n", - " if 'scan=2672' in line: print(line)\n", - " #if 'name=\"ms level\" value=\"2\"' in line: print(line)" + " if \"scan=2672\" in line:\n", + " print(line)\n", + " # if 'name=\"ms level\" value=\"2\"' in line: print(line)" ] }, { @@ -2665,7 +2716,7 @@ "print(spec.getMetaValue(\"filter string\"))\n", "print(spec.getMetaValue(\"[Thermo Trailer Extra]Monoisotopic M/Z:\"))\n", "print(spec.getMSLevel())\n", - "print(spec.getMSLevel() == 2)\n" + "print(spec.getMSLevel() == 2)" ] }, { @@ -2686,7 +2737,7 @@ ], "source": [ "a = spec.getPrecursors()\n", - "a\n" + "a" ] }, { @@ -3177,7 +3228,7 @@ "metadata": {}, "outputs": [], "source": [ - "scan=2672" + "scan = 2672" ] }, { @@ -3187,6 +3238,7 @@ "outputs": [], "source": [ "from pyopenms import *\n", + "\n", "exp = MSExperiment()\n", "MzMLFile().load(path, exp)\n", "\n", diff --git a/lib_loader/gnps_library_loader.ipynb b/lib_loader/gnps_library_loader.ipynb index f5b4c67..78fc9f3 100644 --- a/lib_loader/gnps_library_loader.ipynb +++ b/lib_loader/gnps_library_loader.ipynb @@ -31,12 +31,14 @@ ], "source": [ "import sys\n", - "print(f'Working with Python {sys.version}')\n", + "\n", + "print(f\"Working with Python {sys.version}\")\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import importlib\n", - "#import swifter\n", + "\n", + "# import swifter\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import collections\n", @@ -52,11 +54,14 @@ "\n", "# Deep Learning\n", "import sklearn\n", - "#import spektral\n", + "\n", + "# import spektral\n", "from sklearn.model_selection import train_test_split\n", + "\n", "# Keras\n", "from sklearn.model_selection import train_test_split\n", - "#import stellargraph as sg\n", + "\n", + "# import stellargraph as sg\n", "from rdkit import RDLogger\n", "\n", "\n", @@ -64,16 +69,17 @@ "# 
Load Modules\n", "sys.path.append(\"..\")\n", "from os.path import expanduser\n", + "\n", "home = expanduser(\"~\")\n", "import fiora.IO.mspReader as mspReader\n", "import fiora.visualization.spectrum_visualizer as sv\n", "import fiora.IO.molReader as molReader\n", "\n", "\n", - "RDLogger.DisableLog('rdApp.*')\n", + "RDLogger.DisableLog(\"rdApp.*\")\n", "\n", "\n", - "caffeine_smiles = 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C'\n", + "caffeine_smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\"\n", "caffeine_mol = Chem.MolFromSmiles(caffeine_smiles)\n", "\n", "caffeine_mol" @@ -131,7 +137,7 @@ } ], "source": [ - "df.columns#" + "df.columns #" ] }, { @@ -190,7 +196,7 @@ } ], "source": [ - "#TODO ANALYSE DATASET\n", + "# TODO ANALYSE DATASET\n", "\n", "sns.histplot(data=df, x=\"COLLISION_ENERGY\", hue=\"MS_MANUFACTURER\", multiple=\"dodge\")" ] @@ -246,11 +252,24 @@ "set_light_theme()\n", "\n", "# Define custom color palette\n", - " # Example colors\n", + "# Example colors\n", "\n", "# Create a FacetGrid with each subplot for a different manufacturer\n", - "g = sns.FacetGrid(df.dropna(subset=[\"COLLISION_ENERGY\"]), col=\"MS_MANUFACTURER\", col_wrap=3, sharex=True, sharey=False)\n", - "g.map(sns.histplot, \"COLLISION_ENERGY\", bins=range(0, 101, 10), kde=False, color=\"gray\", edgecolor=\"black\")\n", + "g = sns.FacetGrid(\n", + " df.dropna(subset=[\"COLLISION_ENERGY\"]),\n", + " col=\"MS_MANUFACTURER\",\n", + " col_wrap=3,\n", + " sharex=True,\n", + " sharey=False,\n", + ")\n", + "g.map(\n", + " sns.histplot,\n", + " \"COLLISION_ENERGY\",\n", + " bins=range(0, 101, 10),\n", + " kde=False,\n", + " color=\"gray\",\n", + " edgecolor=\"black\",\n", + ")\n", "\n", "# Set common labels and titles\n", "g.set_axis_labels(\"Collision Energy\", \"Count\")\n", @@ -418,10 +437,10 @@ "metadata": {}, "outputs": [], "source": [ - "#from PyCFMID.PyCFMID import fraggraph_gen\n", - "#caffeine_smiles = 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C'\n", + "# from PyCFMID.PyCFMID import fraggraph_gen\n", + "# caffeine_smiles = 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C'\n", "#\n", - "#frags = fraggraph_gen(caffeine_smiles, max_depth=2, ionization_mode='+', fullgraph=True, output_file=None)" + "# frags = fraggraph_gen(caffeine_smiles, max_depth=2, ionization_mode='+', fullgraph=True, output_file=None)" ] }, { @@ -455,12 +474,14 @@ } ], "source": [ - "nist_msp = mspReader.read(library_directory + library_name + '.MSP')\n", + "nist_msp = mspReader.read(library_directory + library_name + \".MSP\")\n", "df_nist = pd.DataFrame(nist_msp)\n", "\n", - "#df_nist['mol'] = df_nist['SMILES'].apply(Chem.MolFromSmiles)\n", - "#df_nist.dropna(inplace=True)\n", - "print(f\"Spectral file loaded with {df_nist.shape[0]} entries and {df_nist.shape[1]} variables\")\n" + "# df_nist['mol'] = df_nist['SMILES'].apply(Chem.MolFromSmiles)\n", + "# df_nist.dropna(inplace=True)\n", + "print(\n", + " f\"Spectral file loaded with {df_nist.shape[0]} entries and {df_nist.shape[1]} variables\"\n", + ")" ] }, { @@ -470,17 +491,17 @@ "outputs": [], "source": [ "# Search for Example\n", - "#example_entry = \"Desipramine\"\n", - "#x = df_nist[df_nist[\"Name\"] == example_entry]\n", - "#for i in x.index:\n", + "# example_entry = \"Desipramine\"\n", + "# x = df_nist[df_nist[\"Name\"] == example_entry]\n", + "# for i in x.index:\n", "# z = df_nist.loc[i]\n", - "# \n", + "#\n", "# print(z.ID, z.CE)\n", "# print(z[\"peaks\"][\"mz\"])\n", "# print(z[\"peaks\"][\"intensity\"])\n", "# print(\"--------------\")\n", - " \n", - "#print(x.loc[32271])" + "\n", + "# print(x.loc[32271])" ] }, { @@ -532,7 +553,7 @@ 
"source": [ "# Example\n", "example_entry = \"Desipramine\"\n", - "#x = df_nist[df_nist[\"Name\"] == example_entry].iloc[0]\n", + "# x = df_nist[df_nist[\"Name\"] == example_entry].iloc[0]\n", "EXAMPLE_ID = 32271\n", "x = df_nist.loc[EXAMPLE_ID]\n", "print(x)" @@ -558,10 +579,19 @@ "source": [ "# Define figure styles\n", "color_palette = sns.color_palette(\"magma_r\", 8)\n", - "sns.set_theme(style=\"whitegrid\",\n", - " rc={'axes.edgecolor': 'black', 'ytick.left': True, 'xtick.bottom': True, 'xtick.color': 'black',\n", - " \"axes.spines.bottom\": True, \"axes.spines.right\": True, \"axes.spines.top\": True,\n", - " \"axes.spines.left\": True})\n" + "sns.set_theme(\n", + " style=\"whitegrid\",\n", + " rc={\n", + " \"axes.edgecolor\": \"black\",\n", + " \"ytick.left\": True,\n", + " \"xtick.bottom\": True,\n", + " \"xtick.color\": \"black\",\n", + " \"axes.spines.bottom\": True,\n", + " \"axes.spines.right\": True,\n", + " \"axes.spines.top\": True,\n", + " \"axes.spines.left\": True,\n", + " },\n", + ")" ] }, { @@ -583,19 +613,28 @@ } ], "source": [ - "\n", - "\n", - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 2]}, sharey=True)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 2]}, sharey=True\n", + ")\n", "fig.set_tight_layout(False)\n", "for ax in axs:\n", - " ax.tick_params('x', labelrotation=45)\n", - "\n", - "sns.countplot(ax=axs[0], data=df_nist, x='Spectrum_type', edgecolor=\"black\", palette=color_palette)\n", - "sns.countplot(ax=axs[1], data=df_nist, x='Precursor_type', edgecolor=\"black\", palette=color_palette, order=df_nist['Precursor_type'].value_counts().iloc[:8].index)\n", + " ax.tick_params(\"x\", labelrotation=45)\n", + "\n", + "sns.countplot(\n", + " ax=axs[0], data=df_nist, x=\"Spectrum_type\", edgecolor=\"black\", palette=color_palette\n", + ")\n", + "sns.countplot(\n", + " ax=axs[1],\n", + " data=df_nist,\n", + " x=\"Precursor_type\",\n", + " edgecolor=\"black\",\n", + " palette=color_palette,\n", + " order=df_nist[\"Precursor_type\"].value_counts().iloc[:8].index,\n", + ")\n", "axs[0].set_ylim(0, 500000)\n", "axs[1].set_ylabel(\"\")\n", "\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -615,11 +654,13 @@ "# Filters\n", "df_nist = df_nist[df_nist[\"Spectrum_type\"] == \"MS2\"]\n", "target_precursor_type = [\"[M+H]+\", \"[M-H]-\", \"[M+H-H2O]+\", \"[M+Na]+\"]\n", - "df_nist = df_nist[df_nist[\"Precursor_type\"].apply(lambda ptype: ptype in target_precursor_type)]\n", + "df_nist = df_nist[\n", + " df_nist[\"Precursor_type\"].apply(lambda ptype: ptype in target_precursor_type)\n", + "]\n", "\n", "# Formats\n", - "df_nist['PrecursorMZ'] = df_nist[\"PrecursorMZ\"].astype('float')\n", - "df_nist['Num peaks'] = df_nist[\"Num peaks\"].astype('int')\n", + "df_nist[\"PrecursorMZ\"] = df_nist[\"PrecursorMZ\"].astype(\"float\")\n", + "df_nist[\"Num peaks\"] = df_nist[\"Num peaks\"].astype(\"int\")\n", "\n", "\n", "print(f\"Spectral file filtered down to {df_nist.shape[0]} entries\")" @@ -644,14 +685,25 @@ } ], "source": [ - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 2]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 2]}, sharey=False\n", + ")\n", "for ax in axs:\n", - " ax.tick_params('x', labelrotation=45)\n", - "\n", - "sns.boxplot(ax=axs[0], data=df_nist, y='PrecursorMZ', palette=color_palette, x=\"Precursor_type\")\n", - "sns.histplot(ax=axs[1], data=df_nist, x='Num peaks', 
color=color_palette[7], fill=True, edgecolor=\"black\")#, order=list(range(0,200)))\n",
+ " ax.tick_params(\"x\", labelrotation=45)\n",
+ "\n",
+ "sns.boxplot(\n",
+ " ax=axs[0], data=df_nist, y=\"PrecursorMZ\", palette=color_palette, x=\"Precursor_type\"\n",
+ ")\n",
+ "sns.histplot(\n",
+ " ax=axs[1],\n",
+ " data=df_nist,\n",
+ " x=\"Num peaks\",\n",
+ " color=color_palette[7],\n",
+ " fill=True,\n",
+ " edgecolor=\"black\",\n",
+ ") # , order=list(range(0,200)))\n",
 "plt.rcParams[\"patch.force_edgecolor\"] = True\n",
- "plt.rcParams['axes.edgecolor'] = 'black'\n",
+ "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n",
 "axs[1].set_ylabel(\"\")\n",
 "axs[1].set_xlim([0, 100])\n",
 "axs[1].axvline(x=50, color=\"red\", linestyle=\"-.\")\n",
@@ -684,16 +736,20 @@
 "x_mol = molReader.load_MOL(file)\n",
 "x_mol\n",
 "\n",
- "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n",
+ "fig, axs = plt.subplots(\n",
+ " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n",
+ ")\n",
 "\n",
 "img = Chem.Draw.MolToImage(x_mol, ax=axs[0])\n",
 "\n",
 "axs[0].grid(False)\n",
- "axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False)\n",
- "axs[0].set_title(x[\"Name\"]+ \" structure:\\n\" + Chem.MolToSmiles(x_mol))\n",
+ "axs[0].tick_params(\n",
+ " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n",
+ ")\n",
+ "axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + Chem.MolToSmiles(x_mol))\n",
 "axs[0].imshow(img)\n",
 "axs[0].axis(\"off\")\n",
- "sv.plot_spectrum(title=x[\"Name\"] + \" MS/MS spectrum\", spectrum=x, ax=axs[1])\n"
+ "sv.plot_spectrum(title=x[\"Name\"] + \" MS/MS spectrum\", spectrum=x, ax=axs[1])"
 ]
 },
 {
@@ -713,35 +769,47 @@
 }
 ],
 "source": [
+ "# print(df_nist.loc[1474])\n",
+ "\n",
+ "print(\n",
+ " \"Reading structure information in MOL format from library files (this may take a while)\"\n",
+ ")\n",
 "\n",
- "#print(df_nist.loc[1474])\n",
 "\n",
- "print(\"Reading structure information in MOL format from library files (this may take a while)\")\n",
 "def fetch_mol(data):\n",
- " file = library_directory + library_name + \".MOL/\" + \"S\" + str(data[\"CASNO\"]) + \".MOL\"\n",
+ " file = (\n",
+ " library_directory + library_name + \".MOL/\" + \"S\" + str(data[\"CASNO\"]) + \".MOL\"\n",
+ " )\n",
 " if not os.path.exists(file):\n",
- " file = library_directory + library_name + \".MOL/\" + \"ID\" + str(data[\"ID\"]) + \".MOL\"\n",
+ " file = (\n",
+ " library_directory + library_name + \".MOL/\" + \"ID\" + str(data[\"ID\"]) + \".MOL\"\n",
+ " )\n",
 " return molReader.load_MOL(file)\n",
 "\n",
- "df_nist= df_nist[~df_nist[\"InChIKey\"].isnull()] # Drop all without key (Not neccessarily neccesary)\n",
+ "\n",
+ "df_nist = df_nist[\n",
+ " ~df_nist[\"InChIKey\"].isnull()\n",
+ "] # Drop all without key (Not necessarily necessary)\n",
 "df_nist[\"MOL\"] = df_nist.apply(fetch_mol, axis=1)\n",
- "print(f\"Successfully interpreted {sum(df_nist['MOL'].notna())} from {df_nist.shape[0]} entries. Dropping the rest.\")\n",
+ "print(\n",
+ " f\"Successfully interpreted {sum(df_nist['MOL'].notna())} from {df_nist.shape[0]} entries. 
Dropping the rest.\"\n", + ")\n", "\n", - "df_nist = df_nist[df_nist['MOL'].notna()]\n", + "df_nist = df_nist[df_nist[\"MOL\"].notna()]\n", "df_nist[\"SMILES\"] = df_nist[\"MOL\"].apply(Chem.MolToSmiles)\n", "df_nist[\"InChI\"] = df_nist[\"MOL\"].apply(Chem.MolToInchi)\n", "df_nist[\"K\"] = df_nist[\"MOL\"].apply(Chem.MolToInchiKey)\n", "df_nist[\"ExactMolWeight\"] = df_nist[\"MOL\"].apply(Chem.Descriptors.ExactMolWt)\n", "\n", - "#for i in df_nist.index:\n", + "# for i in df_nist.index:\n", "# tight_layout\n", "# x = df_nist.loc[i]\n", - "# \n", + "#\n", "# file = library_directory + library_name + \".MOL/\" + \"S\" + str(x[\"CASNO\"]) + \".MOL\"\n", "# if not os.path.exists(file):\n", "# file = library_directory + library_name + \".MOL/\" + \"ID\" + str(x[\"ID\"]) + \".MOL\"\n", "# print(x[\"ID\"], os.path.exists(file), os.path.exists(file))\n", - "# m = load_MOL(file)\n" + "# m = load_MOL(file)" ] }, { @@ -832,10 +900,16 @@ "source": [ "correct_keys = df_nist.apply(lambda x: x[\"InChIKey\"] == x[\"K\"], axis=1)\n", "s = \"confirmed!\" if correct_keys.all() else \"not confirmed !! Attention!\"\n", - "print(f\"Confirming whether computed and provided InChI-Keys are correct. Result: {s} ({correct_keys.sum()/len(correct_keys):0.2f} correct)\")\n", - "half_keys = df_nist.apply(lambda x: x[\"InChIKey\"].split('-')[0] == x[\"K\"].split('-')[0], axis=1)\n", + "print(\n", + " f\"Confirming whether computed and provided InChI-Keys are correct. Result: {s} ({correct_keys.sum() / len(correct_keys):0.2f} correct)\"\n", + ")\n", + "half_keys = df_nist.apply(\n", + " lambda x: x[\"InChIKey\"].split(\"-\")[0] == x[\"K\"].split(\"-\")[0], axis=1\n", + ")\n", "s = \"confirmed!\" if half_keys.all() else \"not confirmed !! Attention!\"\n", - "print(f\"Checking if main layer InChI-Keys are correct. Result: {s} ({half_keys.sum()/len(half_keys):0.3f} correct)\")\n", + "print(\n", + " f\"Checking if main layer InChI-Keys are correct. 
Result: {s} ({half_keys.sum() / len(half_keys):0.3f} correct)\"\n", + ")\n", "\n", "print(\"Dropping all other.\")\n", "df_nist[\"matching_key\"] = df_nist.apply(lambda x: x[\"InChIKey\"] == x[\"K\"], axis=1)\n", @@ -881,12 +955,17 @@ "from modules.MOL.constants import ADDUCT_WEIGHTS\n", "\n", "\n", - "\n", "df_nist = df_nist[df_nist[\"Num peaks\"] > MIN_PEAKS]\n", "df_nist = df_nist[df_nist[\"Num peaks\"] < MAX_PEAKS]\n", - "df_nist[\"theoretical_precursor_mz\"] = df_nist[\"ExactMolWeight\"] + df_nist[\"Precursor_type\"].map(ADDUCT_WEIGHTS)\n", - "df_nist = df_nist[df_nist[\"Precursor_type\"].apply(lambda ptype: ptype in PRECURSOR_TYPES)]\n", - "df_nist[\"precursor_offset\"] = df_nist[\"PrecursorMZ\"] - df_nist[\"theoretical_precursor_mz\"]\n", + "df_nist[\"theoretical_precursor_mz\"] = df_nist[\"ExactMolWeight\"] + df_nist[\n", + " \"Precursor_type\"\n", + "].map(ADDUCT_WEIGHTS)\n", + "df_nist = df_nist[\n", + " df_nist[\"Precursor_type\"].apply(lambda ptype: ptype in PRECURSOR_TYPES)\n", + "]\n", + "df_nist[\"precursor_offset\"] = (\n", + " df_nist[\"PrecursorMZ\"] - df_nist[\"theoretical_precursor_mz\"]\n", + ")\n", "\n", "print(f\"Shape {df_nist.shape}\")" ] @@ -917,17 +996,32 @@ } ], "source": [ - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 1.5]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 1.5]}, sharey=False\n", + ")\n", "for ax in axs:\n", - " ax.tick_params('x', labelrotation=45)\n", - "\n", - "sns.scatterplot(ax=axs[0], data=df_nist,x=\"precursor_offset\", y='PrecursorMZ', palette=color_palette)\n", - "sns.histplot(ax=axs[1], data=df_nist, x='precursor_offset', color=color_palette[7], fill=True, edgecolor=\"black\")#, order=list(range(0,200)))\n", + " ax.tick_params(\"x\", labelrotation=45)\n", + "\n", + "sns.scatterplot(\n", + " ax=axs[0],\n", + " data=df_nist,\n", + " x=\"precursor_offset\",\n", + " y=\"PrecursorMZ\",\n", + " palette=color_palette,\n", + ")\n", + "sns.histplot(\n", + " ax=axs[1],\n", + " data=df_nist,\n", + " x=\"precursor_offset\",\n", + " color=color_palette[7],\n", + " fill=True,\n", + " edgecolor=\"black\",\n", + ") # , order=list(range(0,200)))\n", "plt.rcParams[\"patch.force_edgecolor\"] = True\n", - "plt.rcParams['axes.edgecolor'] = 'black'\n", + "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n", "axs[1].set_ylabel(\"\")\n", - "#axs[1].set_xlim([0, 100])\n", - "#axs[1].axvline(x=50, color=\"red\", linestyle=\"-.\")\n", + "# axs[1].set_xlim([0, 100])\n", + "# axs[1].axvline(x=50, color=\"red\", linestyle=\"-.\")\n", "\n", "plt.show()" ] @@ -961,15 +1055,15 @@ } ], "source": [ - "\n", "# TODO Use more Collision energy types. eg. 
ramps, resonant...\n", "\n", - "import modules.MOL.collision_energy # TODO MOVE ALIGN CE BACK TO modules.MOL.\n", + "import modules.MOL.collision_energy # TODO MOVE ALIGN CE BACK TO modules.MOL.\n", "\n", "\n", "def NCE_to_eV(nce, precursor_mz, charge=1):\n", " return nce * precursor_mz / 500 * charge_factor[charge]\n", "\n", + "\n", "def align_CE(ce, precursor_mz):\n", " if type(ce) == float:\n", " return ce\n", @@ -977,45 +1071,52 @@ " ce = ce.replace(\"eV\", \"\")\n", " return float(ce)\n", " elif \"%\" in ce:\n", - " nce = ce.split('%')[0].strip().split(' ')[-1]\n", + " nce = ce.split(\"%\")[0].strip().split(\" \")[-1]\n", " try:\n", " nce = float(nce)\n", " return NCE_to_eV(nce, precursor_mz)\n", " except:\n", " return ce\n", " else:\n", - " try: \n", + " try:\n", " ce = float(ce)\n", " return ce\n", " except:\n", " return ce\n", "\n", - "charge_factor = {1: 1, 2: 0.9, 3: 0.85, 4: 0.8, 5: 0.75}\n", "\n", + "charge_factor = {1: 1, 2: 0.9, 3: 0.85, 4: 0.8, 5: 0.75}\n", "\n", "\n", - "df_nist[\"CE\"] = df_nist.apply(lambda x: align_CE(x[\"Collision_energy\"], x[\"theoretical_precursor_mz\"]), axis=1) #modules.MOL.collision_energy.align_CE) \n", + "df_nist[\"CE\"] = df_nist.apply(\n", + " lambda x: align_CE(x[\"Collision_energy\"], x[\"theoretical_precursor_mz\"]), axis=1\n", + ") # modules.MOL.collision_energy.align_CE)\n", "df_nist[\"CE_type\"] = df_nist[\"CE\"].apply(type)\n", - "df_nist[\"CE_derived_from_NCE\"] = df_nist[\"Collision_energy\"].apply(lambda x: \"%\" in str(x))\n", + "df_nist[\"CE_derived_from_NCE\"] = df_nist[\"Collision_energy\"].apply(\n", + " lambda x: \"%\" in str(x)\n", + ")\n", "# df_test = df_nist[df_nist[\"Collision_energy\"].apply(lambda x: \"%\" in str(x))][\"Collision_energy\"]\n", "# df_test = df_test.apply(lambda x: x.split('%')[0].strip().split(' ')[-1])\n", "# for d in df_test:\n", - "# try: \n", + "# try:\n", "# float(d)\n", "# except:\n", "# print(d)\n", "\n", "\n", - "\n", - "print(\"Distinguish CE absolute values (eV - float) and normalized CE (in % - str format)\")\n", + "print(\n", + " \"Distinguish CE absolute values (eV - float) and normalized CE (in % - str format)\"\n", + ")\n", "print(df_nist[\"CE_type\"].value_counts())\n", "\n", "print(\"Removing all but absolute values\")\n", "df_nist = df_nist[df_nist[\"CE_type\"] == float]\n", "df_nist = df_nist[~df_nist[\"CE\"].isnull()]\n", - "#len(df_nist['CE'].unique())\n", + "# len(df_nist['CE'].unique())\n", "\n", - "print(f'Detected {len(df_nist[\"CE\"].unique())} unique collision energies in range from {np.min(df_nist[\"CE\"])} to {max(df_nist[\"CE\"])} eV')\n" + "print(\n", + " f\"Detected {len(df_nist['CE'].unique())} unique collision energies in range from {np.min(df_nist['CE'])} to {max(df_nist['CE'])} eV\"\n", + ")" ] }, { @@ -1045,13 +1146,24 @@ ], "source": [ "fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4))\n", - "#for ax in axs:\n", + "# for ax in axs:\n", "# ax.tick_params('x', labelrotation=45)\n", "\n", - "#sns.scatterplot(ax=axs[0], data=df_nist,x=\"precursor_offset\", y='PrecursorMZ', palette=color_palette)\n", - "sns.histplot(ax=ax, data=df_nist, x='CE', hue=\"CE_derived_from_NCE\", palette=[color_palette[4], color_palette[2]], multiple=\"stack\", fill=True, binwidth=2, edgecolor=\"black\", binrange=[0, 200])#, order=list(range(0,200)))\n", + "# sns.scatterplot(ax=axs[0], data=df_nist,x=\"precursor_offset\", y='PrecursorMZ', palette=color_palette)\n", + "sns.histplot(\n", + " ax=ax,\n", + " data=df_nist,\n", + " x=\"CE\",\n", + " hue=\"CE_derived_from_NCE\",\n", + " 
palette=[color_palette[4], color_palette[2]],\n", + " multiple=\"stack\",\n", + " fill=True,\n", + " binwidth=2,\n", + " edgecolor=\"black\",\n", + " binrange=[0, 200],\n", + ") # , order=list(range(0,200)))\n", "plt.rcParams[\"patch.force_edgecolor\"] = True\n", - "plt.rcParams['axes.edgecolor'] = 'black'\n", + "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n", "plt.show()\n", "print(f\"{df_nist.shape[0]} spectra remaining with aligned absolute collision energies\")" ] @@ -1075,6 +1187,7 @@ "%%capture\n", "from modules.MOL.Metabolite import Metabolite\n", "from modules.MOL.constants import PPM\n", + "\n", "TOLERANCE = 200 * PPM\n", "\n", "\n", @@ -1082,7 +1195,12 @@ "df_nist[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", "df_nist[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes())\n", "df_nist[\"Metabolite\"].apply(lambda x: x.fragment_MOL())\n", - "df_nist.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=TOLERANCE), axis=1)" + "df_nist.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=TOLERANCE\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -1091,9 +1209,17 @@ "metadata": {}, "outputs": [], "source": [ - "from modules.MOL.mol_graph import mol_to_graph, get_adjacency_matrix, get_degree_matrix, get_edges, get_identity_matrix, draw_graph\n", + "from modules.MOL.mol_graph import (\n", + " mol_to_graph,\n", + " get_adjacency_matrix,\n", + " get_degree_matrix,\n", + " get_edges,\n", + " get_identity_matrix,\n", + " draw_graph,\n", + ")\n", "\n", "from modules.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", + "\n", "node_encoder = AtomFeatureEncoder()" ] }, @@ -1103,29 +1229,34 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "num_elems = 12\n", + "\n", + "\n", "def add_dataframe_features(df):\n", - " df['graph'] = df['MOL'].apply(mol_to_graph)\n", - " #df['features'] = df['graph'].apply(node_encoder.encode)\n", - " df['A'] = df['graph'].apply(get_adjacency_matrix)\n", - " df['Atilde'] = df['A'].apply(lambda x: x + np.eye(N=x.shape[0]))\n", - " df['Id'] = df['A'].apply(get_identity_matrix)\n", - " df['deg'] = df['A'].apply(get_degree_matrix)\n", - " df['is_aromatic'] = df['graph'].apply(lambda x: np.array([[x.nodes[atom]['is_aromatic'] for atom in x.nodes()]]).T)\n", + " df[\"graph\"] = df[\"MOL\"].apply(mol_to_graph)\n", + " # df['features'] = df['graph'].apply(node_encoder.encode)\n", + " df[\"A\"] = df[\"graph\"].apply(get_adjacency_matrix)\n", + " df[\"Atilde\"] = df[\"A\"].apply(lambda x: x + np.eye(N=x.shape[0]))\n", + " df[\"Id\"] = df[\"A\"].apply(get_identity_matrix)\n", + " df[\"deg\"] = df[\"A\"].apply(get_degree_matrix)\n", + " df[\"is_aromatic\"] = df[\"graph\"].apply(\n", + " lambda x: np.array([[x.nodes[atom][\"is_aromatic\"] for atom in x.nodes()]]).T\n", + " )\n", "\n", " # Extras\n", "\n", - " #df['Xsymbol'] = df['graph'].apply(lambda x: [x.nodes[atom]['atom_symbol'] for atom in x.nodes()])\n", - " #df['Xi'] = df['graph'].apply(lambda x: [min(x.nodes[atom]['atomic_num'], num_elems - 1) for atom in x.nodes()])\n", - " #df['X'] = df['Xi'].apply(lambda x: to_categorical(x, num_classes=num_elems))\n", - " \n", - " #df['isN'] = df['graph'].apply(lambda x: np.array([[int(x.nodes[atom]['atom_symbol'] == 'N') for atom in x.nodes()]]))\n", - " #df['isN_in_radius1'] = [df.loc[i, 'Atilde'] * df.loc[i,'isN'].T for i in df.index]\n", - " #df['isN_in_radius1'] = 
df['isN_in_radius1'].apply(lambda x: x.clip(0, 1))\n", - " #df['isN_neighboring'] = [df.loc[i, 'A'] * df.loc[i,'isN'].T for i in df.index]\n", - " #df['isN_neighboring'] = df['isN_neighboring'].apply(lambda x: x.clip(0, 1))\n", + " # df['Xsymbol'] = df['graph'].apply(lambda x: [x.nodes[atom]['atom_symbol'] for atom in x.nodes()])\n", + " # df['Xi'] = df['graph'].apply(lambda x: [min(x.nodes[atom]['atomic_num'], num_elems - 1) for atom in x.nodes()])\n", + " # df['X'] = df['Xi'].apply(lambda x: to_categorical(x, num_classes=num_elems))\n", + "\n", + " # df['isN'] = df['graph'].apply(lambda x: np.array([[int(x.nodes[atom]['atom_symbol'] == 'N') for atom in x.nodes()]]))\n", + " # df['isN_in_radius1'] = [df.loc[i, 'Atilde'] * df.loc[i,'isN'].T for i in df.index]\n", + " # df['isN_in_radius1'] = df['isN_in_radius1'].apply(lambda x: x.clip(0, 1))\n", + " # df['isN_neighboring'] = [df.loc[i, 'A'] * df.loc[i,'isN'].T for i in df.index]\n", + " # df['isN_neighboring'] = df['isN_neighboring'].apply(lambda x: x.clip(0, 1))\n", " return df\n", + "\n", + "\n", "df_nist = add_dataframe_features(df_nist)" ] }, @@ -1155,13 +1286,17 @@ "source": [ "x = df_nist.loc[EXAMPLE_ID]\n", "\n", - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 1]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 1]}, sharey=False\n", + ")\n", "\n", "img = Chem.Draw.MolToImage(x_mol, ax=axs[0])\n", "\n", "axs[0].grid(False)\n", - "axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False)\n", - "axs[0].set_title(x[\"Name\"]+ \" structure:\\n\" + Chem.MolToSmiles(x_mol))\n", + "axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + ")\n", + "axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + Chem.MolToSmiles(x_mol))\n", "axs[0].imshow(img)\n", "axs[0].axis(\"off\")\n", "\n", @@ -1177,8 +1312,8 @@ "source": [ "import networkx as nx\n", "\n", - "#TODO Refactor compute_helper_matrices and so on\n", - "#def add_dataframe_edge_features(df):\n", + "# TODO Refactor compute_helper_matrices and so on\n", + "# def add_dataframe_edge_features(df):\n", "# df['AL'] = df.apply(lambda x: compute_helper_matrices(nx.convert_matrix.to_numpy_matrix(x[\"graph\"]), x['deg'], x['graph']), axis=1)\n", "# df['AR'] = df['AL'].apply(lambda x: x[1])\n", "# df['edges_is_aromatic'] = df['AL'].apply(lambda x: np.array([x[2]]).T)\n", @@ -1186,7 +1321,7 @@ "# df['AL'] = df['AL'].apply(lambda x: x[0])\n", "# return df\n", "\n", - "#df_nist = add_dataframe_edge_features(df_nist)" + "# df_nist = add_dataframe_edge_features(df_nist)" ] }, { @@ -1244,6 +1379,7 @@ ], "source": [ "from modules.MOL.FragmentationTree import FragmentationTree\n", + "\n", "importlib.reload(modules.MOL.FragmentationTree)\n", "importlib.reload(modules.MOL.mol_graph)\n", "\n", @@ -1300,14 +1436,18 @@ ], "source": [ "import modules.IO.fraggraphReader as fraggraphReader\n", + "\n", "importlib.reload(modules.MS.ms_utility)\n", "from modules.MS.ms_utility import find_matching_peaks, match_fragment_lists\n", + "\n", "importlib.reload(modules.IO.fraggraphReader)\n", "importlib.reload(modules.MOL.FragmentationTree)\n", "from modules.MOL.FragmentationTree import FragmentationTree\n", "\n", "\n", - "f = fraggraphReader.parser_fraggraph_gen(library_directory + \"examples/CNCCCN1c2ccccc2CCc2ccccc21_fraggraph.txt\")\n", + "f = fraggraphReader.parser_fraggraph_gen(\n", + " library_directory + 
\"examples/CNCCCN1c2ccccc2CCc2ccccc21_fraggraph.txt\"\n", + ")\n", "x = df_nist.loc[EXAMPLE_ID]\n", "\n", "match_fragment_lists(x[\"peaks\"][\"mz\"], f[\"fragments\"][\"mass\"])" @@ -1367,26 +1507,29 @@ } ], "source": [ - "\n", "x = df_nist.loc[EXAMPLE_ID]\n", "\n", "FT = x[\"Metabolite\"].fragmentation_tree\n", - "#frag.build_fragmentation_tree_by_rotatable_bond_breaks()\n", + "# frag.build_fragmentation_tree_by_rotatable_bond_breaks()\n", "print(FT)\n", "\n", - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + ")\n", "\n", "img = Chem.Draw.MolToImage(x[\"MOL\"], ax=axs[0])\n", "\n", "axs[0].grid(False)\n", - "axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False)\n", - "axs[0].set_title(x[\"Name\"]+ \" structure:\\n\" + x[\"SMILES\"])\n", + "axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + ")\n", + "axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + x[\"SMILES\"])\n", "axs[0].imshow(img)\n", "axs[0].axis(\"off\")\n", "sv.plot_spectrum(title=x[\"Name\"] + \" MS/MS spectrum\", spectrum=x, ax=axs[1])\n", "\n", "print(\"Matching peaks to fragments\")\n", - "print(x[\"Metabolite\"].peak_matches)\n" + "print(x[\"Metabolite\"].peak_matches)" ] }, { @@ -1422,13 +1565,18 @@ } ], "source": [ - "\n", - "df_nist[\"peak_matches\"] = df_nist[\"Metabolite\"].apply(lambda x: getattr(x, \"peak_matches\"))\n", + "df_nist[\"peak_matches\"] = df_nist[\"Metabolite\"].apply(\n", + " lambda x: getattr(x, \"peak_matches\")\n", + ")\n", "df_nist[\"num_peaks_matched\"] = df_nist[\"peak_matches\"].apply(len)\n", "\n", "\n", "def get_match_stats(matches):\n", - " num_unique, num_conflicts, mode_count = 0, 0, {\"[M+H]+\": 0, \"[M-H]+\": 0, \"[M-3H]+\": 0}\n", + " num_unique, num_conflicts, mode_count = (\n", + " 0,\n", + " 0,\n", + " {\"[M+H]+\": 0, \"[M-H]+\": 0, \"[M-3H]+\": 0},\n", + " )\n", " for mz, match_data in matches.items():\n", " candidates = match_data[\"fragments\"]\n", " if len(candidates) == 1:\n", @@ -1438,21 +1586,27 @@ " for c in candidates:\n", " mode_count[c[1][0]] += 1\n", " return num_unique, num_conflicts, mode_count\n", - "d = df_nist.loc[EXAMPLE_ID]\n", "\n", "\n", + "d = df_nist.loc[EXAMPLE_ID]\n", + "\n", "\n", "df_nist[\"match_stats\"] = df_nist[\"peak_matches\"].apply(lambda x: get_match_stats(x))\n", - "df_nist[\"num_unique_peaks_matched\"] = df_nist.apply(lambda x: x[\"match_stats\"][0], axis=1)\n", - "df_nist[\"num_conflicts_in_peak_matching\"] = df_nist.apply(lambda x: x[\"match_stats\"][1], axis=1)\n", + "df_nist[\"num_unique_peaks_matched\"] = df_nist.apply(\n", + " lambda x: x[\"match_stats\"][0], axis=1\n", + ")\n", + "df_nist[\"num_conflicts_in_peak_matching\"] = df_nist.apply(\n", + " lambda x: x[\"match_stats\"][1], axis=1\n", + ")\n", "df_nist[\"match_mode_counts\"] = df_nist.apply(lambda x: x[\"match_stats\"][2], axis=1)\n", - "u= df_nist[\"num_unique_peaks_matched\"].sum() \n", - "s= df_nist[\"num_conflicts_in_peak_matching\"].sum() \n", - "print(f\"Total number of uniquely matched peaks: {u} , conflicts found within {s} matches ({100 * s / (u+s):.02f} %))\")\n", + "u = df_nist[\"num_unique_peaks_matched\"].sum()\n", + "s = df_nist[\"num_conflicts_in_peak_matching\"].sum()\n", + "print(\n", + " f\"Total number of uniquely matched peaks: {u} , conflicts found within {s} matches ({100 * s / (u + s):.02f} 
%))\"\n", + ")\n", "print(f\"Total number of conflicting peak to fragment matches: {s}\")\n", "\n", - "df_nist.shape\n", - " " + "df_nist.shape" ] }, { @@ -1477,19 +1631,30 @@ "fig, axs = plt.subplots(1, 3, figsize=(12.8, 6.4), sharey=False)\n", "\n", "fig.suptitle(f\"Identified peaks with fragment offset\")\n", - "#plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n", - "sns.histplot(ax=axs[0],data=df_nist, x=\"num_peaks_matched\", color=color_palette[0], edgecolor=\"black\", bins=range(0,20, 1))\n", - "#axs[0].set_ylim(-0.5, 10)\n", + "# plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n", + "sns.histplot(\n", + " ax=axs[0],\n", + " data=df_nist,\n", + " x=\"num_peaks_matched\",\n", + " color=color_palette[0],\n", + " edgecolor=\"black\",\n", + " bins=range(0, 20, 1),\n", + ")\n", + "# axs[0].set_ylim(-0.5, 10)\n", "axs[0].set_ylabel(\"peaks identified\")\n", "\n", "\n", - "sns.boxplot(ax=axs[1],data=df_nist, y=\"num_unique_peaks_matched\", color=color_palette[1])\n", + "sns.boxplot(\n", + " ax=axs[1], data=df_nist, y=\"num_unique_peaks_matched\", color=color_palette[1]\n", + ")\n", "axs[1].set_ylim(-0.5, 15)\n", "axs[1].set_xlabel(\"unique matches\")\n", "axs[1].set_ylabel(\"\")\n", "\n", "\n", - "sns.boxplot(ax=axs[2],data=df_nist, y=\"num_conflicts_in_peak_matching\", color=color_palette[3])\n", + "sns.boxplot(\n", + " ax=axs[2], data=df_nist, y=\"num_conflicts_in_peak_matching\", color=color_palette[3]\n", + ")\n", "axs[2].set_ylim(-0.5, 15)\n", "axs[2].set_xlabel(\"conflicts\")\n", "axs[2].set_ylabel(\"\")\n", @@ -1520,14 +1685,28 @@ "\n", "mode_counts = {\"[M+H]+\": 0, \"[M-H]+\": 0, \"[M-3H]+\": 0}\n", "\n", + "\n", "def update_mode_counts(m):\n", " for mode in m.keys():\n", " mode_counts[mode] += m[mode]\n", "\n", + "\n", "df_nist[\"match_mode_counts\"].apply(update_mode_counts)\n", "\n", - "sns.barplot(ax=axs[0], x=list(mode_counts.keys()), y=[mode_counts[k] for k in mode_counts.keys()], palette=color_palette, edgecolor=\"black\", linewidth=1.5)\n", - "axs[1].pie([mode_counts[k] for k in mode_counts.keys()], labels=list(mode_counts.keys()), colors=color_palette, wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5})\n", + "sns.barplot(\n", + " ax=axs[0],\n", + " x=list(mode_counts.keys()),\n", + " y=[mode_counts[k] for k in mode_counts.keys()],\n", + " palette=color_palette,\n", + " edgecolor=\"black\",\n", + " linewidth=1.5,\n", + ")\n", + "axs[1].pie(\n", + " [mode_counts[k] for k in mode_counts.keys()],\n", + " labels=list(mode_counts.keys()),\n", + " colors=color_palette,\n", + " wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5},\n", + ")\n", "\n", "plt.show()" ] @@ -1551,9 +1730,11 @@ } ], "source": [ - "\n", - "for i in range(0,6):\n", - " print(f\"Minimum {i} unique peaks identified (including precursors): \", (df_nist[\"num_unique_peaks_matched\"] >= i).sum())\n" + "for i in range(0, 6):\n", + " print(\n", + " f\"Minimum {i} unique peaks identified (including precursors): \",\n", + " (df_nist[\"num_unique_peaks_matched\"] >= i).sum(),\n", + " )" ] }, { @@ -1576,22 +1757,58 @@ "min_peaks = 5\n", "\n", "if save_df:\n", - " key_columns = ['Name', 'Synon', 'Notes', 'Precursor_type', 'Spectrum_type',\n", - " 'PrecursorMZ', 'Instrument_type', 'Instrument', 'Sample_inlet',\n", - " 'Ionization', 'Collision_energy', 'Ion_mode', 'Special_fragmentation',\n", - " 'InChIKey', 'Formula', 'MW', 'ExactMass', 'CASNO', 'NISTNO', 'ID',\n", - " 'Comment', 'Num peaks', 'peaks', 'Link', 'Related_CAS#',\n", - " 'Collision_gas', 'Pressure', 
'In-source_voltage', 'msN_pathway', 'MOL',\n", - " 'SMILES', 'InChI', 'K', 'ExactMolWeight', 'matching_key',\n", - " 'theoretical_precursor_mz', 'precursor_offset', 'CE', 'CE_type', 'peak_matches',\n", - " 'num_peaks_matched', 'match_stats', 'num_unique_peaks_matched',\n", - " 'num_conflicts_in_peak_matching', 'match_mode_counts']\n", - " file = library_directory + name + \"_min\" + str(min_peaks) + \"_\" + date + \".csv\"\n", - " print(\"saving to \", file)\n", - " df_nist[df_nist[\"num_unique_peaks_matched\"] >= min_peaks][key_columns].to_csv(file)\n", - " \n", - " #df_nist[key_columns].to_csv(library_directory + name + \"all\" + \"_\" + date + \".csv\")\n", - "\n" + " key_columns = [\n", + " \"Name\",\n", + " \"Synon\",\n", + " \"Notes\",\n", + " \"Precursor_type\",\n", + " \"Spectrum_type\",\n", + " \"PrecursorMZ\",\n", + " \"Instrument_type\",\n", + " \"Instrument\",\n", + " \"Sample_inlet\",\n", + " \"Ionization\",\n", + " \"Collision_energy\",\n", + " \"Ion_mode\",\n", + " \"Special_fragmentation\",\n", + " \"InChIKey\",\n", + " \"Formula\",\n", + " \"MW\",\n", + " \"ExactMass\",\n", + " \"CASNO\",\n", + " \"NISTNO\",\n", + " \"ID\",\n", + " \"Comment\",\n", + " \"Num peaks\",\n", + " \"peaks\",\n", + " \"Link\",\n", + " \"Related_CAS#\",\n", + " \"Collision_gas\",\n", + " \"Pressure\",\n", + " \"In-source_voltage\",\n", + " \"msN_pathway\",\n", + " \"MOL\",\n", + " \"SMILES\",\n", + " \"InChI\",\n", + " \"K\",\n", + " \"ExactMolWeight\",\n", + " \"matching_key\",\n", + " \"theoretical_precursor_mz\",\n", + " \"precursor_offset\",\n", + " \"CE\",\n", + " \"CE_type\",\n", + " \"peak_matches\",\n", + " \"num_peaks_matched\",\n", + " \"match_stats\",\n", + " \"num_unique_peaks_matched\",\n", + " \"num_conflicts_in_peak_matching\",\n", + " \"match_mode_counts\",\n", + " ]\n", + " file = library_directory + name + \"_min\" + str(min_peaks) + \"_\" + date + \".csv\"\n", + " print(\"saving to \", file)\n", + " df_nist[df_nist[\"num_unique_peaks_matched\"] >= min_peaks][key_columns].to_csv(file)\n", + "\n", + " # df_nist[key_columns].to_csv(library_directory + name + \"all\" + \"_\" + date + \".csv\")" ] }, { @@ -1626,8 +1843,8 @@ "from torchmetrics import Accuracy, MetricTracker\n", "from modules.GNN.MLPEdgeClassifier import MLPEdgeClassifier\n", "from modules.GNN.GNNModels import GCNNodeClassifier\n", - "#importlib.reload(modules.GNN.GNNModels)\n", - "#importlib.reload(modules.GNN.GNNLayers)\n" + "# importlib.reload(modules.GNN.GNNModels)\n", + "# importlib.reload(modules.GNN.GNNLayers)" ] }, { @@ -1636,21 +1853,19 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "def train(model, dataloader_training, optimizer, loss_fn, tracker, epochs=10):\n", " for e in range(epochs):\n", - " print(f'Epoch {e + 1}/{epochs}')\n", + " print(f\"Epoch {e + 1}/{epochs}\")\n", " training_loss = 0\n", " tracker.increment()\n", " for batch_id, (X, y) in enumerate(dataloader_training):\n", - " \n", " # Complete the implementation.\n", " model.train()\n", " # Compute predictions based on the current set of parameters\n", " y_pred = model(X)\n", " # Compute prediction error\n", " loss = loss_fn(y_pred, y)\n", - " tracker.update(y, (y_pred>0).int())\n", + " tracker.update(y, (y_pred > 0).int())\n", " # Reset partial derivatives\n", " optimizer.zero_grad()\n", " # Compute partial derivatives\n", @@ -1660,61 +1875,70 @@ "\n", " # Record loss\n", " training_loss += loss.item()\n", - " print(f'Avg. training loss {training_loss / (batch_id + 1) :>.3f}', end='\\r')\n", + " print(f\"Avg. 
training loss {training_loss / (batch_id + 1):>.3f}\", end=\"\\r\")\n", "\n", - " print('')\n", - " print(f'Accuracy: {tracker.compute():>.4f}')\n", + " print(\"\")\n", + " print(f\"Accuracy: {tracker.compute():>.4f}\")\n", " training_loss /= len(dataloader_training)\n", "\n", " return training_loss\n", "\n", + "\n", "def validate_gnn(model, dataloader_val, loss_fn, tracker):\n", " tracker.increment()\n", " validation_loss = 0\n", "\n", " with torch.no_grad():\n", " for batch_id, (X, A, y) in enumerate(dataloader_val):\n", - " model.eval()\n", - " y_pred = model(X, A)\n", - " loss = loss_fn(y_pred, y)\n", - "\n", - " tracker.update(y_pred[0], y[0].int())\n", - " validation_loss += loss.item()\n", + " model.eval()\n", + " y_pred = model(X, A)\n", + " loss = loss_fn(y_pred, y)\n", "\n", - " if batch_id == 500: break\n", + " tracker.update(y_pred[0], y[0].int())\n", + " validation_loss += loss.item()\n", "\n", + " if batch_id == 500:\n", + " break\n", "\n", " val_accuracy = tracker.compute()\n", " validation_loss /= len(dataloader_val)\n", - " print(f' Validation Accuracy: {val_accuracy:>.3f} (Loss: {validation_loss:>.3f})')\n", + " print(f\" Validation Accuracy: {val_accuracy:>.3f} (Loss: {validation_loss:>.3f})\")\n", "\n", " return val_accuracy\n", "\n", "\n", - "def train_gnn(model, training_data, optimizer, loss_fn, tracker, batch_size=1, epochs=5):\n", + "def train_gnn(\n", + " model, training_data, optimizer, loss_fn, tracker, batch_size=1, epochs=5\n", + "):\n", " acc = []\n", - " \n", + "\n", " train_proportion = 0.8\n", - " train_size = int(len(training_data)*train_proportion)\n", - " training_data, validation_data = torch.utils.data.random_split(training_data, [train_size, len(training_data) - train_size], generator=torch.Generator().manual_seed(42))\n", - " dataloader_training, dataloader_val = DataLoader(training_data, batch_size=batch_size), DataLoader(validation_data, batch_size=batch_size)\n", + " train_size = int(len(training_data) * train_proportion)\n", + " training_data, validation_data = torch.utils.data.random_split(\n", + " training_data,\n", + " [train_size, len(training_data) - train_size],\n", + " generator=torch.Generator().manual_seed(42),\n", + " )\n", + " dataloader_training, dataloader_val = (\n", + " DataLoader(training_data, batch_size=batch_size),\n", + " DataLoader(validation_data, batch_size=batch_size),\n", + " )\n", "\n", " for e in range(epochs):\n", - " print(f'Epoch {e + 1}/{epochs}')\n", + " print(f\"Epoch {e + 1}/{epochs}\")\n", " training_loss = 0\n", " tracker.increment()\n", " for batch_id, (X, A, y) in enumerate(dataloader_training):\n", - " \n", " # Feed forward\n", " model.train()\n", " y_pred = model(X, A)\n", - " #print(y_pred)\n", - " #print(y)\n", - " #print(errör)\n", + " # print(y_pred)\n", + " # print(y)\n", + " # print(errör)\n", " loss = loss_fn(y_pred, y)\n", - " #tracker.update(y[0], (y_pred[0]>0).int())\n", - " #tracker.update(y[0], (y_pred>0).int())\n", - " tracker.update(y_pred[0], y[0].int()) # with logits\n", + " # tracker.update(y[0], (y_pred[0]>0).int())\n", + " # tracker.update(y[0], (y_pred>0).int())\n", + " tracker.update(y_pred[0], y[0].int()) # with logits\n", "\n", " # Backpropagate\n", " optimizer.zero_grad()\n", @@ -1724,13 +1948,17 @@ " # Record loss\n", " training_loss += loss.item()\n", " if batch_id % 100 == 0:\n", - " print(f' Avg. training loss {training_loss / (batch_id + 1) :>.4f}', end='\\r')\n", - " if batch_id == 2000: break\n", + " print(\n", + " f\" Avg. 
training loss {training_loss / (batch_id + 1):>.4f}\",\n", + " end=\"\\r\",\n", + " )\n", + " if batch_id == 2000:\n", + " break\n", " # On epoch end: Evaluation\n", " accuracy = tracker.compute()\n", " acc.append(accuracy)\n", " training_loss /= len(dataloader_training)\n", - " print(f' Training Accuracy: {accuracy:>.3f} (Loss: {training_loss:>.3f})')\n", + " print(f\" Training Accuracy: {accuracy:>.3f} (Loss: {training_loss:>.3f})\")\n", " validate_gnn(model, dataloader_val, loss_fn, tracker)\n", "\n", " return acc" @@ -1744,13 +1972,13 @@ "source": [ "class AtomAromaticityData(Dataset):\n", " def __init__(self) -> None:\n", - " #super().__init__()\n", - " self.X = np.concatenate(df_nist[\"features\"].values, dtype='float32')\n", - " self.y = np.concatenate(df_nist[\"is_aromatic\"].values*1, dtype='float32')\n", - " \n", + " # super().__init__()\n", + " self.X = np.concatenate(df_nist[\"features\"].values, dtype=\"float32\")\n", + " self.y = np.concatenate(df_nist[\"is_aromatic\"].values * 1, dtype=\"float32\")\n", + "\n", " # ADD label to input features\n", " # self.X = [np.append(self.X[i], self.y[i]) for i in range(len(self))]\n", - " \n", + "\n", " def __len__(self):\n", " return len(self.X)\n", "\n", @@ -1799,14 +2027,21 @@ } ], "source": [ - "\n", - "\n", - "dummy_model = MLPEdgeClassifier(23/2.0, 0, 0)\n", + "dummy_model = MLPEdgeClassifier(23 / 2.0, 0, 0)\n", "training_data = AtomAromaticityData()\n", - "optimizer = torch.optim.SGD(dummy_model.parameters(), momentum = 0.9, lr=0.01)\n", - "loss_fn = torch.nn.BCEWithLogitsLoss() #Alternatively for sigmoid use torch.nn.BCELoss()\n", + "optimizer = torch.optim.SGD(dummy_model.parameters(), momentum=0.9, lr=0.01)\n", + "loss_fn = (\n", + " torch.nn.BCEWithLogitsLoss()\n", + ") # Alternatively for sigmoid use torch.nn.BCELoss()\n", "tracker = MetricTracker(Accuracy(num_classes=1))\n", - "train(dummy_model, DataLoader(training_data, batch_size=64), optimizer=optimizer, loss_fn=loss_fn, epochs=5, tracker=tracker)" + "train(\n", + " dummy_model,\n", + " DataLoader(training_data, batch_size=64),\n", + " optimizer=optimizer,\n", + " loss_fn=loss_fn,\n", + " epochs=5,\n", + " tracker=tracker,\n", + ")" ] }, { @@ -1825,13 +2060,26 @@ "class AtomAromaticityData(Dataset):\n", " def __init__(self) -> None:\n", "\n", - " self.X = df_nist[\"features\"].apply(lambda x: torch.tensor(x, dtype=torch.float32)).values#[:1000]\n", - " self.A = df_nist[\"Atilde\"].apply(lambda x: torch.tensor(x, dtype=torch.float32)).values\n", - " self.y = df_nist[\"is_aromatic\"].apply(lambda x: torch.tensor(x, dtype=torch.float32)).values*1\n", + " self.X = (\n", + " df_nist[\"features\"]\n", + " .apply(lambda x: torch.tensor(x, dtype=torch.float32))\n", + " .values\n", + " ) # [:1000]\n", + " self.A = (\n", + " df_nist[\"Atilde\"]\n", + " .apply(lambda x: torch.tensor(x, dtype=torch.float32))\n", + " .values\n", + " )\n", + " self.y = (\n", + " df_nist[\"is_aromatic\"]\n", + " .apply(lambda x: torch.tensor(x, dtype=torch.float32))\n", + " .values\n", + " * 1\n", + " )\n", "\n", " # ADD label to input features\n", " # self.X = [np.append(self.X[i], self.y[i]) for i in range(len(self))]\n", - " \n", + "\n", " def __len__(self):\n", " return len(self.X)\n", "\n", @@ -1842,7 +2090,7 @@ " return self.X[0].shape[1]\n", "\n", "\n", - "training_data = AtomAromaticityData()\n" + "training_data = AtomAromaticityData()" ] }, { @@ -1853,9 +2101,8 @@ "source": [ "def plot_acc(acc):\n", " fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4), sharey=False)\n", - " #plt.ylim(0.5, 
1.0)\n", - " plt.plot(range(1, len(acc) + 1), acc, linewidth=2)\n", - "\n" + " # plt.ylim(0.5, 1.0)\n", + " plt.plot(range(1, len(acc) + 1), acc, linewidth=2)" ] }, { @@ -1886,13 +2133,24 @@ } ], "source": [ - "gcn_node_model = GCNNodeClassifier(training_data.num_features(), training_data.num_features(), 1, depth=0)\n", + "gcn_node_model = GCNNodeClassifier(\n", + " training_data.num_features(), training_data.num_features(), 1, depth=0\n", + ")\n", "\n", - "#optimizer = torch.optim.SGD(gcn_node_model.parameters(), momentum = 0.9, lr=0.01)\n", + "# optimizer = torch.optim.SGD(gcn_node_model.parameters(), momentum = 0.9, lr=0.01)\n", "optimizer = torch.optim.Adam(gcn_node_model.parameters(), lr=0.01)\n", - "loss_fn = torch.nn.BCEWithLogitsLoss() #Alternatively for sigmoid use torch.nn.BCELoss()\n", + "loss_fn = (\n", + " torch.nn.BCEWithLogitsLoss()\n", + ") # Alternatively for sigmoid use torch.nn.BCELoss()\n", "tracker = MetricTracker(Accuracy(num_classes=1))\n", - "acc = train_gnn(gcn_node_model, training_data, optimizer=optimizer, loss_fn=loss_fn, epochs=5, tracker=tracker)" + "acc = train_gnn(\n", + " gcn_node_model,\n", + " training_data,\n", + " optimizer=optimizer,\n", + " loss_fn=loss_fn,\n", + " epochs=5,\n", + " tracker=tracker,\n", + ")" ] }, { @@ -1923,13 +2181,24 @@ } ], "source": [ - "gcn_node_model = GCNNodeClassifier(training_data.num_features(), training_data.num_features(), 1, depth=1)\n", + "gcn_node_model = GCNNodeClassifier(\n", + " training_data.num_features(), training_data.num_features(), 1, depth=1\n", + ")\n", "\n", - "#optimizer = torch.optim.SGD(gcn_node_model.parameters(), momentum = 0.9, lr=0.01)\n", + "# optimizer = torch.optim.SGD(gcn_node_model.parameters(), momentum = 0.9, lr=0.01)\n", "optimizer = torch.optim.Adam(gcn_node_model.parameters(), lr=0.01)\n", - "loss_fn = torch.nn.BCEWithLogitsLoss() #Alternatively for sigmoid use torch.nn.BCELoss()\n", + "loss_fn = (\n", + " torch.nn.BCEWithLogitsLoss()\n", + ") # Alternatively for sigmoid use torch.nn.BCELoss()\n", "tracker = MetricTracker(Accuracy(num_classes=1))\n", - "acc = train_gnn(gcn_node_model, training_data, optimizer=optimizer, loss_fn=loss_fn, epochs=5, tracker=tracker)" + "acc = train_gnn(\n", + " gcn_node_model,\n", + " training_data,\n", + " optimizer=optimizer,\n", + " loss_fn=loss_fn,\n", + " epochs=5,\n", + " tracker=tracker,\n", + ")" ] }, { @@ -1964,13 +2233,22 @@ } ], "source": [ - "gcn_node_model = GCNNodeClassifier(2,2, 1, depth=5)\n", - "#for param in gcn_node_model.layers[0].parameters():\n", - "# print(param) \n", + "gcn_node_model = GCNNodeClassifier(2, 2, 1, depth=5)\n", + "# for param in gcn_node_model.layers[0].parameters():\n", + "# print(param)\n", "optimizer = torch.optim.Adam(gcn_node_model.parameters(), lr=0.01)\n", - "loss_fn = torch.nn.BCEWithLogitsLoss() #Alternatively for sigmoid use torch.nn.BCELoss()\n", + "loss_fn = (\n", + " torch.nn.BCEWithLogitsLoss()\n", + ") # Alternatively for sigmoid use torch.nn.BCELoss()\n", "tracker = MetricTracker(Accuracy(num_classes=1))\n", - "acc = train_gnn(gcn_node_model, training_data, optimizer=optimizer, loss_fn=loss_fn, epochs=5, tracker=tracker)" + "acc = train_gnn(\n", + " gcn_node_model,\n", + " training_data,\n", + " optimizer=optimizer,\n", + " loss_fn=loss_fn,\n", + " epochs=5,\n", + " tracker=tracker,\n", + ")" ] }, { @@ -2139,7 +2417,7 @@ ], "source": [ "for param in gcn_node_model.layers[0].parameters():\n", - " print(param) " + " print(param)" ] }, { @@ -2157,6 +2435,7 @@ "source": [ "import tensorflow as tf\n", "\n", + 
"\n", "# idea: H' = sigmoid(AHW) with adjacency matrix A, feature matrix H, and weight matrix W (linear transformation)\n", "def gnn_pool(features, A, transform, activation):\n", " HW = transform(features)\n", @@ -2164,6 +2443,7 @@ "\n", " return activation(AHW)\n", "\n", + "\n", "def gnn(features, A, self_transform, transform, activation):\n", " HW = transform(features)\n", " AHW = tf.matmul(A, HW)\n", @@ -2172,20 +2452,38 @@ " return activation(tf.add(HV, AHW))\n", "\n", "\n", - "\n", "num_elems = training_data.num_features()\n", "\n", - "layer_conv1, layer_self1 = tf.keras.layers.Dense(num_elems), tf.keras.layers.Dense(num_elems)\n", - "layer_conv2, layer_self2 = tf.keras.layers.Dense(num_elems), tf.keras.layers.Dense(num_elems)\n", - "layer_conv3, layer_self3 = tf.keras.layers.Dense(num_elems), tf.keras.layers.Dense(num_elems)\n", - "layer_conv4, layer_self4 = tf.keras.layers.Dense(num_elems), tf.keras.layers.Dense(num_elems)\n", - "layer_conv5, layer_self5 = tf.keras.layers.Dense(num_elems), tf.keras.layers.Dense(num_elems)\n", + "layer_conv1, layer_self1 = (\n", + " tf.keras.layers.Dense(num_elems),\n", + " tf.keras.layers.Dense(num_elems),\n", + ")\n", + "layer_conv2, layer_self2 = (\n", + " tf.keras.layers.Dense(num_elems),\n", + " tf.keras.layers.Dense(num_elems),\n", + ")\n", + "layer_conv3, layer_self3 = (\n", + " tf.keras.layers.Dense(num_elems),\n", + " tf.keras.layers.Dense(num_elems),\n", + ")\n", + "layer_conv4, layer_self4 = (\n", + " tf.keras.layers.Dense(num_elems),\n", + " tf.keras.layers.Dense(num_elems),\n", + ")\n", + "layer_conv5, layer_self5 = (\n", + " tf.keras.layers.Dense(num_elems),\n", + " tf.keras.layers.Dense(num_elems),\n", + ")\n", "layer_final = tf.keras.layers.Dense(1)\n", "\n", + "\n", "def gnn_pooling_model(features, A):\n", - " hidden_features = gnn_pool(features, A, layer_conv1, activation=tf.nn.relu)\n", - " output_logits = layer_final(hidden_features) #= gnn(hidden_features, A, layer_2, activation=tf.identity)#tf.nn.sigmoid)#tf.identity)\n", - " return output_logits\n", + " hidden_features = gnn_pool(features, A, layer_conv1, activation=tf.nn.relu)\n", + " output_logits = layer_final(\n", + " hidden_features\n", + " ) # = gnn(hidden_features, A, layer_2, activation=tf.identity)#tf.nn.sigmoid)#tf.identity)\n", + " return output_logits\n", + "\n", "\n", "def gnn_model(features, A):\n", " hidden_features = gnn(features, A, layer_self1, layer_conv1, tf.nn.relu)\n", @@ -2196,6 +2494,7 @@ " output_logits = layer_final(hidden_features)\n", " return output_logits\n", "\n", + "\n", "def gnn_model_d1(features, A):\n", " hidden_features = gnn(features, A, layer_self1, layer_conv1, tf.nn.relu)\n", " output_logits = layer_final(hidden_features)\n", @@ -2255,59 +2554,81 @@ } ], "source": [ - "\n", "def validate_tf_model(data, model, y_label, tracker, verbose=False, **kwargs):\n", " losses, y_true, y_hat = [], [], []\n", " correct_mol = 0\n", " tracker.increment()\n", " for batch_id, (X, A, y) in enumerate(data):\n", - " \n", - " logits = model(X.detach().numpy(), A.detach().numpy(),)\n", - " #logits = model(d.features, tf.cast(d.Id, dtype=tf.float32), **model_kwargs)\n", - " #y_tensor = tf.cast(d[y_label], dtype=tf.float32)\n", - " loss = tf.nn.sigmoid_cross_entropy_with_logits(tf.cast(y, dtype=tf.float32), logits)\n", + " logits = model(\n", + " X.detach().numpy(),\n", + " A.detach().numpy(),\n", + " )\n", + " # logits = model(d.features, tf.cast(d.Id, dtype=tf.float32), **model_kwargs)\n", + " # y_tensor = tf.cast(d[y_label], dtype=tf.float32)\n", + " loss = 
tf.nn.sigmoid_cross_entropy_with_logits(\n",
+    "            tf.cast(y, dtype=tf.float32), logits\n",
+    "        )\n",
     "\n",
     "        losses = np.append(losses, loss.numpy())\n",
     "\n",
     "        tracker.update(torch.from_numpy(logits[0].numpy()), y[0].int())\n",
     "        y_hat += [float(x) > 0 for x in logits]\n",
     "        y_true += [int(x) for x in y]\n",
-    "        if y_hat[-len(logits):] == y_true[-len(logits):]:\n",
-    "            correct_mol+=1\n",
+    "        if y_hat[-len(logits) :] == y_true[-len(logits) :]:\n",
+    "            correct_mol += 1\n",
     "\n",
     "    acc, mean_loss = sklearn.metrics.accuracy_score(y_true, y_hat), np.mean(losses)\n",
-    "    if verbose: \n",
-    "        print(\"Node/Edge level accuracy: %.03f; Mean loss: %.03f; Correct molecules: %.03f\" % (acc, mean_loss, correct_mol / len(data)))\n",
+    "    if verbose:\n",
+    "        print(\n",
+    "            \"Node/Edge level accuracy: %.03f; Mean loss: %.03f; Correct molecules: %.03f\"\n",
+    "            % (acc, mean_loss, correct_mol / len(data))\n",
+    "        )\n",
     "        print(f\"Torch MetricTracker Accuracy: {tracker.compute():.3f}\")\n",
     "    return acc, mean_loss, correct_mol / len(data)\n",
     "\n",
+    "\n",
     "def train_tf_gnn(data, model):\n",
     "    tracker = MetricTracker(Accuracy(num_classes=1, multiclass=False))\n",
     "    optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)\n",
     "    train_proportion = 0.8\n",
-    "    train_size = int(len(data)*train_proportion)\n",
-    "    training_data, validation_data = torch.utils.data.random_split(data, [train_size, len(data) - train_size], generator=torch.Generator().manual_seed(42))\n",
-    "    dataloader_training, dataloader_val = DataLoader(training_data, batch_size=1), DataLoader(validation_data, batch_size=1)\n",
+    "    train_size = int(len(data) * train_proportion)\n",
+    "    training_data, validation_data = torch.utils.data.random_split(\n",
+    "        data,\n",
+    "        [train_size, len(data) - train_size],\n",
+    "        generator=torch.Generator().manual_seed(42),\n",
+    "    )\n",
+    "    dataloader_training, dataloader_val = (\n",
+    "        DataLoader(training_data, batch_size=1),\n",
+    "        DataLoader(validation_data, batch_size=1),\n",
+    "    )\n",
     "    for epoch in range(1, 6):\n",
     "        print(\"Epoch %s\" % epoch)\n",
     "        training_loss = 0\n",
     "        training_loss_torch = 0\n",
     "        tracker.increment()\n",
     "        for batch_id, (X, A, y) in enumerate(dataloader_training):\n",
-    "            #logits = gnn_pooling_model(d.features, tf.cast(d.Atilde, dtype=tf.float32))\n",
+    "            # logits = gnn_pooling_model(d.features, tf.cast(d.Atilde, dtype=tf.float32))\n",
     "            with tf.GradientTape() as t:\n",
-    "                logits = model(X.detach().numpy(), A.detach().numpy()) #tf.cast(A, dtype=tf.float32))\n",
-    "                loss = tf.nn.sigmoid_cross_entropy_with_logits(tf.cast(y, dtype=tf.float32), logits)\n",
+    "                logits = model(\n",
+    "                    X.detach().numpy(), A.detach().numpy()\n",
+    "                )  # tf.cast(A, dtype=tf.float32))\n",
+    "                loss = tf.nn.sigmoid_cross_entropy_with_logits(\n",
+    "                    tf.cast(y, dtype=tf.float32), logits\n",
+    "                )\n",
     "                print(torch.from_numpy(logits.numpy()))\n",
     "                print(y)\n",
-    "                torch_loss = torch.nn.BCEWithLogitsLoss(torch.from_numpy(logits[0].numpy()), y[0])\n",
-    "                \n",
-    "                #print(torch.tensor((logits[0]>0).numpy()).int())\n",
-    "                #print(y[0].int())\n",
-    "                #print('p: ', logits)\n",
-    "                #print('y: ', y)\n",
-    "                #print(errroror)\n",
-    "                tracker.update(torch.from_numpy(logits[0].numpy()), y[0].int())#torch.tensor((logits[0]>0).numpy()).int(), y[0].int())\n",
+    "                # nn.BCEWithLogitsLoss is a module class; the functional call\n",
+    "                # below actually computes the loss value\n",
+    "                torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(\n",
+    "                    torch.from_numpy(logits[0].numpy()), y[0]\n",
+    "                )\n",
+    "\n",
+    "                # print(torch.tensor((logits[0]>0).numpy()).int())\n",
+    "                # print(y[0].int())\n",
+    "                # print('p: ', logits)\n",
+    "                # print('y: ', y)\n",
+    "                # print(errroror)\n",
+    "                tracker.update(\n",
+    "                    torch.from_numpy(logits[0].numpy()), y[0].int()\n",
+    "                )  # torch.tensor((logits[0]>0).numpy()).int(), y[0].int())\n",
     "\n",
     "            variables = t.watched_variables()\n",
     "            gradients = t.gradient(loss, variables)\n",
@@ -2316,25 +2637,37 @@
     "            training_loss_torch += torch_loss.item()\n",
     "\n",
     "            if batch_id % 100 == 0:\n",
-    "                print(f'    Avg. training loss {training_loss / (batch_id + 1) :>.4f} (torch loss: {training_loss_torch / (batch_id + 1)}', end='\\r')\n",
+    "                print(\n",
+    "                    f\"    Avg. training loss {training_loss / (batch_id + 1):>.4f} (torch loss: {training_loss_torch / (batch_id + 1)}\",\n",
+    "                    end=\"\\r\",\n",
+    "                )\n",
     "\n",
     "        # On epoch end: Evaluation\n",
     "        accuracy = tracker.compute()\n",
     "        acc.append(accuracy)\n",
     "        training_loss /= len(dataloader_training)\n",
-    "        print(f'    Training Accuracy: {accuracy:>.3f} (Loss: {training_loss:>.3f})', flush=True)\n",
-    "        #validate_gnn(gnn_model, dataloader_val, loss_fn, tracker)\n",
+    "        print(\n",
+    "            f\"    Training Accuracy: {accuracy:>.3f} (Loss: {training_loss:>.3f})\",\n",
+    "            flush=True,\n",
+    "        )\n",
+    "        # validate_gnn(gnn_model, dataloader_val, loss_fn, tracker)\n",
     "\n",
     "        # Validate loss/acc\n",
-    "        validate_tf_model(data=validation_data, model=model, y_label=\"is_aromatic\", verbose=True, tracker=tracker)\n",
+    "        validate_tf_model(\n",
+    "            data=validation_data,\n",
+    "            model=model,\n",
+    "            y_label=\"is_aromatic\",\n",
+    "            verbose=True,\n",
+    "            tracker=tracker,\n",
+    "        )\n",
     "\n",
     "    return\n",
     "\n",
     "\n",
-    "#df_train, df_test = train_test_split(df, test_size=0.5)\n",
-    "#print(df_train.shape)\n",
+    "# df_train, df_test = train_test_split(df, test_size=0.5)\n",
+    "# print(df_train.shape)\n",
     "\n",
-    "train_tf_gnn(training_data, gnn_model)\n"
+    "train_tf_gnn(training_data, gnn_model)"
    ]
   },
   {
@@ -2393,6 +2726,7 @@
     "def accuracy_logits(y, yhat):\n",
     "    return sklearn.metrics.accuracy_score(y, [x > 0 for x in yhat])\n",
     "\n",
+    "\n",
     "def validate_tf_model_old(data, model, y_label, verbose=False, **kwargs):\n",
     "    losses, y_true, y_hat = [], [], []\n",
     "    correct_mol = 0\n",
@@ -2401,27 +2735,36 @@
     "        model_kwargs = {}\n",
     "        for key, value in kwargs.items():\n",
     "            model_kwargs[key] = tf.cast(d[value], dtype=tf.float32)\n",
-    "        logits = model(d.features, tf.cast(d.A / d.deg, dtype=tf.float32), **model_kwargs)\n",
-    "        #logits = model(d.features, tf.cast(d.Id, dtype=tf.float32), **model_kwargs)\n",
+    "        logits = model(\n",
+    "            d.features, tf.cast(d.A / d.deg, dtype=tf.float32), **model_kwargs\n",
+    "        )\n",
+    "        # logits = model(d.features, tf.cast(d.Id, dtype=tf.float32), **model_kwargs)\n",
     "        y_tensor = tf.cast(d[y_label], dtype=tf.float32)\n",
     "        loss = tf.nn.sigmoid_cross_entropy_with_logits(y_tensor, logits)\n",
     "\n",
     "        losses = np.append(losses, loss.numpy())\n",
     "        y_hat += [float(x) > 0 for x in logits]\n",
     "        y_true += [int(x) for x in y_tensor]\n",
-    "        if y_hat[-len(logits):] == y_true[-len(logits):]:\n",
-    "            correct_mol+=1\n",
+    "        if y_hat[-len(logits) :] == y_true[-len(logits) :]:\n",
+    "            correct_mol += 1\n",
     "\n",
     "    acc, mean_loss = sklearn.metrics.accuracy_score(y_true, y_hat), np.mean(losses)\n",
-    "    if verbose: print(\"Node/Edge level accuracy: %.03f; Mean loss: %.03f; Correct molecules: %.03f\" % (acc, mean_loss, correct_mol / data.shape[0]))\n",
+    "    if verbose:\n",
+    "        print(\n",
+    "            \"Node/Edge level accuracy: %.03f; Mean loss: %.03f; Correct molecules: %.03f\"\n",
+    "            % (acc, mean_loss, correct_mol / data.shape[0])\n",
+    "        )\n",
     "\n",
     "    return acc, mean_loss, correct_mol / data.shape[0]\n",
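+    "\n",
+    "\n",
+    "# Sanity-check sketch (hypothetical helper, not used elsewhere): the training\n",
+    "# loop above mirrors the TF loss with a Torch loss, and the two BCE-with-logits\n",
+    "# implementations should agree elementwise on identical inputs. The toy logits\n",
+    "# and targets below are made-up values for illustration only.\n",
+    "def compare_bce_with_logits():\n",
+    "    logits = np.array([0.5, -1.2, 2.0], dtype=np.float32)\n",
+    "    targets = np.array([1.0, 0.0, 1.0], dtype=np.float32)\n",
+    "    tf_loss = tf.nn.sigmoid_cross_entropy_with_logits(\n",
+    "        tf.constant(targets), tf.constant(logits)\n",
+    "    )\n",
+    "    pt_loss = torch.nn.functional.binary_cross_entropy_with_logits(\n",
+    "        torch.from_numpy(logits), torch.from_numpy(targets), reduction=\"none\"\n",
+    "    )\n",
+    "    return np.allclose(tf_loss.numpy(), pt_loss.numpy())  # expected: True\n",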
"\n", + "\n", "def train_tf_gnn_old(data):\n", "\n", " optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)\n", - " #training_data, validation_data = torch.utils.data.random_split(training_data, [train_size, len(training_data) - train_size], generator=torch.Generator().manual_seed(42))\n", - " data_train, data_val = train_test_split(data, test_size=0.2)#DataLoader(training_data, batch_size=1), DataLoader(validation_data, batch_size=1)\n", + " # training_data, validation_data = torch.utils.data.random_split(training_data, [train_size, len(training_data) - train_size], generator=torch.Generator().manual_seed(42))\n", + " data_train, data_val = train_test_split(\n", + " data, test_size=0.2\n", + " ) # DataLoader(training_data, batch_size=1), DataLoader(validation_data, batch_size=1)\n", "\n", " tracker = MetricTracker(Accuracy(num_classes=1, multiclass=False))\n", "\n", @@ -2430,9 +2773,11 @@ " for i in data_train.index:\n", " with tf.GradientTape() as t:\n", " d = data_train.loc[i]\n", - " #logits = gnn_pooling_model(d.features, tf.cast(d.Atilde, dtype=tf.float32))\n", + " # logits = gnn_pooling_model(d.features, tf.cast(d.Atilde, dtype=tf.float32))\n", " logits = gnn_model(d.features, tf.cast(d.A / d.deg, dtype=tf.float32))\n", - " loss = tf.nn.sigmoid_cross_entropy_with_logits(tf.cast(d.is_aromatic, dtype=tf.float32), logits)\n", + " loss = tf.nn.sigmoid_cross_entropy_with_logits(\n", + " tf.cast(d.is_aromatic, dtype=tf.float32), logits\n", + " )\n", " print(logits, d.is_aromatic)\n", " tracker.update(logits, d.is_aromatic)\n", " variables = t.watched_variables()\n", @@ -2441,7 +2786,9 @@ "\n", " # Validate loss/acc\n", " print(\"Epoch %s\" % epoch)\n", - " validate_tf_model_old(data=data_val, model=gnn_model, y_label=\"is_aromatic\", verbose=True)\n", + " validate_tf_model_old(\n", + " data=data_val, model=gnn_model, y_label=\"is_aromatic\", verbose=True\n", + " )\n", "\n", " return\n", "\n", @@ -2592,12 +2939,23 @@ ], "source": [ "training_data = AtomAromaticityData()\n", - "gcn_node_model = GCNNodeClassifier(training_data.num_features(), training_data.num_features(), 1, depth=5)\n", + "gcn_node_model = GCNNodeClassifier(\n", + " training_data.num_features(), training_data.num_features(), 1, depth=5\n", + ")\n", "\n", "optimizer = torch.optim.AdamW(gcn_node_model.parameters(), lr=0.01)\n", - "loss_fn = torch.nn.BCEWithLogitsLoss() #Alternatively for sigmoid use torch.nn.BCELoss()\n", + "loss_fn = (\n", + " torch.nn.BCEWithLogitsLoss()\n", + ") # Alternatively for sigmoid use torch.nn.BCELoss()\n", "tracker = MetricTracker(Accuracy(num_classes=1))\n", - "train_gnn(gcn_node_model, DataLoader(training_data, batch_size=1), optimizer=optimizer, loss_fn=loss_fn, epochs=5, tracker=tracker)" + "train_gnn(\n", + " gcn_node_model,\n", + " DataLoader(training_data, batch_size=1),\n", + " optimizer=optimizer,\n", + " loss_fn=loss_fn,\n", + " epochs=5,\n", + " tracker=tracker,\n", + ")" ] }, { @@ -2685,16 +3043,16 @@ ], "source": [ "print(gcn_node_model)\n", - "#X, A, y = a.__getitem__(0)\n", + "# X, A, y = a.__getitem__(0)\n", "\n", - "#print(X)\n", - "for _, (X,A,y) in enumerate(aa):\n", + "# print(X)\n", + "for _, (X, A, y) in enumerate(aa):\n", " print(y[0])\n", " break\n", "with torch.no_grad():\n", " y_pred = gcn_node_model(X, A)\n", " print(y, torch.sigmoid(y_pred))\n", - " #print(y, y_pred[0] > 0)\n", + " # print(y, y_pred[0] > 0)\n", " tracker.update(y[0], y_pred[0] > 0)\n", "\n", "tracker.compute()" @@ -2727,30 +3085,36 @@ "class BondAromaticityData(Dataset):\n", " def __init__(self) -> 
None:\n", "\n", - " edge_connected_node_features = df_nist.apply(lambda x: self.concatenate_node_features(x), axis = 1)\n", - " self.X = np.concatenate(edge_connected_node_features.values, dtype='float32')\n", - " self.y = np.concatenate(df_nist[\"edges_is_aromatic\"].values*1, dtype='float32')\n", - " \n", + " edge_connected_node_features = df_nist.apply(\n", + " lambda x: self.concatenate_node_features(x), axis=1\n", + " )\n", + " self.X = np.concatenate(edge_connected_node_features.values, dtype=\"float32\")\n", + " self.y = np.concatenate(\n", + " df_nist[\"edges_is_aromatic\"].values * 1, dtype=\"float32\"\n", + " )\n", + "\n", " # ADD label to input features\n", " # self.X = [np.append(self.X[i], self.y[i]) for i in range(len(self))]\n", - " \n", + "\n", " def concatenate_node_features(self, d):\n", - " return np.concatenate([d['AL']@d['features'], d['AR']@d['features']], axis=1)\n", - " \n", + " return np.concatenate(\n", + " [d[\"AL\"] @ d[\"features\"], d[\"AR\"] @ d[\"features\"]], axis=1\n", + " )\n", "\n", " def __len__(self):\n", " return len(self.X)\n", "\n", - " def __getitem__(self, idx): \n", + " def __getitem__(self, idx):\n", " return [self.X[idx], self.y[idx]]\n", "\n", " def num_features(self):\n", " return self.X.shape[1]\n", "\n", - "#test_x = df_nist.iloc[0]\n", "\n", - "#X_x = np.concatenate([test_x['AL']@test_x['features'], test_x['AR']@test_x['features']], axis=1)\n", - "#X_x.shape\n" + "# test_x = df_nist.iloc[0]\n", + "\n", + "# X_x = np.concatenate([test_x['AL']@test_x['features'], test_x['AR']@test_x['features']], axis=1)\n", + "# X_x.shape" ] }, { @@ -2787,10 +3151,19 @@ "source": [ "model = MLPEdgeClassifier(23, 0, 0)\n", "training_data = BondAromaticityData()\n", - "optimizer = torch.optim.SGD(model.parameters(), momentum = 0.9, lr=0.01)\n", - "loss_fn = torch.nn.BCEWithLogitsLoss() #Alternatively for sigmoid use torch.nn.BCELoss()\n", + "optimizer = torch.optim.SGD(model.parameters(), momentum=0.9, lr=0.01)\n", + "loss_fn = (\n", + " torch.nn.BCEWithLogitsLoss()\n", + ") # Alternatively for sigmoid use torch.nn.BCELoss()\n", "tracker = MetricTracker(Accuracy(num_classes=1))\n", - "train(model, DataLoader(training_data, batch_size=64), optimizer=optimizer, loss_fn=loss_fn, epochs=3, tracker=tracker)" + "train(\n", + " model,\n", + " DataLoader(training_data, batch_size=64),\n", + " optimizer=optimizer,\n", + " loss_fn=loss_fn,\n", + " epochs=3,\n", + " tracker=tracker,\n", + ")" ] }, { @@ -2822,10 +3195,11 @@ "outputs": [], "source": [ "from modules.MOL.FragmentationTree import FragmentationTree\n", + "\n", "importlib.reload(modules.MOL.FragmentationTree)\n", "\n", "FT = FragmentationTree(x[\"MOL\"])\n", - "#frag.build_fragmentation_tree_by_rotatable_bond_breaks()\n", + "# frag.build_fragmentation_tree_by_rotatable_bond_breaks()\n", "FT.build_fragmentation_tree(x[\"MOL\"], x.edges_idx, depth=1)\n", "\n", "x_matches = FT.match_peak_list(df_nist.loc[EXAMPLE_ID][\"peaks\"][\"mz\"])\n", @@ -2841,29 +3215,37 @@ "metadata": {}, "outputs": [], "source": [ - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + ")\n", "\n", "img = Chem.Draw.MolToImage(x[\"MOL\"], ax=axs[0])\n", "\n", "axs[0].grid(False)\n", - "axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False)\n", - "axs[0].set_title(x[\"Name\"]+ \" structure:\\n\" + 
Chem.MolToSmiles(x[\"MOL\"]))\n", + "axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + ")\n", + "axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + Chem.MolToSmiles(x[\"MOL\"]))\n", "axs[0].imshow(img)\n", "axs[0].axis(\"off\")\n", "sv.plot_spectrum(title=x[\"Name\"] + \" MS/MS spectrum\", spectrum=x, ax=axs[1])\n", - "axs[1].text(100, 0.20, 'GeeksforGeeks', style ='italic',\n", - " fontsize = 30, color =\"green\")\n", + "axs[1].text(100, 0.20, \"GeeksforGeeks\", style=\"italic\", fontsize=30, color=\"green\")\n", "\n", "\n", "print(\"m/z\", x[\"peaks\"][\"mz\"])\n", "print(\"Int\", x[\"peaks\"][\"intensity\"])\n", "print(\"\\nCreate Fragmentation Tree (depth = 1)\\n\")\n", "print(FT)\n", - "#print(t.size(level=0), t.size(level=1))# t.size(level=2))\n", + "# print(t.size(level=0), t.size(level=1))# t.size(level=2))\n", "\n", "mols = [x[\"MOL\"], FT.get_fragment(3), FT.get_fragment(10), FT.get_fragment(7)]\n", "\n", - "Chem.Draw.MolsToGridImage(mols, molsPerRow=4, useSVG=True, legends=[f' mol ({Chem.Descriptors.ExactMolWt(m)})' for m in mols])" + "Chem.Draw.MolsToGridImage(\n", + " mols,\n", + " molsPerRow=4,\n", + " useSVG=True,\n", + " legends=[f\" mol ({Chem.Descriptors.ExactMolWt(m)})\" for m in mols],\n", + ")" ] }, { @@ -4939,7 +5321,12 @@ "# TODO ::::::::::::::::::::::::::::::::::::: CONTINUE ::::::::::::::::::::::::::::::::::::::::::\n", "#\n", "\n", - "Chem.Draw.MolsToGridImage([df_nist.iloc[i][\"MOL\"] for i in range(50)], molsPerRow=4, useSVG=True, legends=[str(i) for i in range(50)]) #" + "Chem.Draw.MolsToGridImage(\n", + " [df_nist.iloc[i][\"MOL\"] for i in range(50)],\n", + " molsPerRow=4,\n", + " useSVG=True,\n", + " legends=[str(i) for i in range(50)],\n", + ") #" ] }, { @@ -5102,10 +5489,9 @@ } ], "source": [ - "\n", "n = df_nist.iloc[40].Name\n", "print(n)\n", - "df_nist[df_nist.Name==n][[\"CE\", \"Num peaks\", \"peaks\"]]\n" + "df_nist[df_nist.Name == n][[\"CE\", \"Num peaks\", \"peaks\"]]" ] }, { @@ -5117,38 +5503,50 @@ "EXAMPLE_2_ID = 495920\n", "\n", "\n", - "\n", "x = df_nist.loc[EXAMPLE_2_ID]\n", "\n", - "#print(x)\n", + "# print(x)\n", "\n", "FT = FragmentationTree()\n", "FT.build_fragmentation_tree_by_single_edge_breaks(x[\"MOL\"], x.edges_idx, depth=1)\n", "\n", - "#print(Chem.Descriptors.ExactMolWt(x[\"MOL\"]))\n", - "#print(Chem.Descriptors.ExactMolWt(t.get_node(6).data))\n", + "# print(Chem.Descriptors.ExactMolWt(x[\"MOL\"]))\n", + "# print(Chem.Descriptors.ExactMolWt(t.get_node(6).data))\n", "\n", - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + ")\n", "\n", "img = Chem.Draw.MolToImage(x[\"MOL\"], ax=axs[0])\n", "\n", "axs[0].grid(False)\n", - "axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False)\n", - "axs[0].set_title(x[\"Name\"]+ \" structure:\\n\" + Chem.MolToSmiles(x[\"MOL\"]))\n", + "axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + ")\n", + "axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + Chem.MolToSmiles(x[\"MOL\"]))\n", "axs[0].imshow(img)\n", "axs[0].axis(\"off\")\n", "sv.plot_spectrum(title=x[\"Name\"] + \" MS/MS spectrum\", spectrum=x, ax=axs[1])\n", - "axs[1].text(100, 0.20, 'GeeksforGeeks', style ='italic',\n", - " fontsize = 30, color =\"green\")\n", + "axs[1].text(100, 0.20, \"GeeksforGeeks\", 
style=\"italic\", fontsize=30, color=\"green\")\n", "\n", "\n", "print(\"m/z\", x[\"peaks\"][\"mz\"])\n", "print(\"Int\", x[\"peaks\"][\"intensity\"])\n", "print(\"\\nCreate Fragmentation Tree (depth = 1)\\n\")\n", "t.show(idhidden=False)\n", - "#print(t.size(level=0), t.size(level=1))# t.size(level=2))\n", + "# print(t.size(level=0), t.size(level=1))# t.size(level=2))\n", "\n", - "Chem.Draw.MolsToGridImage([x[\"MOL\"], t.get_node(1).data, t.get_node(4).data, t.get_node(5).data], molsPerRow=4, useSVG=True, legends=[\"intact\", f\"frag ({t.get_node(1).tag:.03f})\", f\"frag ({t.get_node(4).tag:.03f})\", f\"frag ({t.get_node(5).tag:.03f})\"])" + "Chem.Draw.MolsToGridImage(\n", + " [x[\"MOL\"], t.get_node(1).data, t.get_node(4).data, t.get_node(5).data],\n", + " molsPerRow=4,\n", + " useSVG=True,\n", + " legends=[\n", + " \"intact\",\n", + " f\"frag ({t.get_node(1).tag:.03f})\",\n", + " f\"frag ({t.get_node(4).tag:.03f})\",\n", + " f\"frag ({t.get_node(5).tag:.03f})\",\n", + " ],\n", + ")" ] }, { @@ -5158,8 +5556,9 @@ "outputs": [], "source": [ "import modules.MS.ms_utility as msutil\n", + "\n", "importlib.reload(modules.MS.ms_utility)\n", - "importlib.reload(modules.MOL.FragmentationTree)\n" + "importlib.reload(modules.MOL.FragmentationTree)" ] }, { @@ -5182,7 +5581,13 @@ ], "source": [ "for off in offsets:\n", - " print(off, \":\", np.mean(D[str(off)][\"peaks\"]), np.mean(D[str(off)][\"unique\"]), np.mean(D[str(off)][\"percentage\"]))" + " print(\n", + " off,\n", + " \":\",\n", + " np.mean(D[str(off)][\"peaks\"]),\n", + " np.mean(D[str(off)][\"unique\"]),\n", + " np.mean(D[str(off)][\"percentage\"]),\n", + " )" ] }, { @@ -5232,23 +5637,23 @@ } ], "source": [ - "#sns.boxplot(y=D[str(off)][\"peaks\"], palette=color_palette)\n", - "#plt.show()\n", + "# sns.boxplot(y=D[str(off)][\"peaks\"], palette=color_palette)\n", + "# plt.show()\n", "for off in offsets:\n", " fig, axs = plt.subplots(1, 3, figsize=(12.8, 6.4), sharey=False)\n", "\n", " fig.suptitle(f\"Identified peaks with fragment offset: {str(off)}\")\n", - " #plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n", + " # plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n", " sns.boxplot(ax=axs[0], y=D[str(off)][\"peaks\"], color=color_palette[0])\n", " axs[0].set_ylim(-0.5, 10)\n", " axs[0].set_ylabel(\"peaks identified\")\n", " sns.boxplot(ax=axs[1], y=D[str(off)][\"unique\"], color=color_palette[1])\n", - " axs[1].set_ylim(-0.5,10)\n", + " axs[1].set_ylim(-0.5, 10)\n", " axs[1].set_ylabel(\"peaks uniquely identified\")\n", " sns.boxplot(ax=axs[2], y=D[str(off)][\"percentage\"], color=color_palette[2])\n", - " axs[2].set_ylim(-0.05,1.0)\n", + " axs[2].set_ylim(-0.05, 1.0)\n", " axs[2].set_ylabel(\"peak percentags (by number)\")\n", - " plt.show()\n" + " plt.show()" ] }, { @@ -5258,9 +5663,14 @@ "outputs": [], "source": [ "print(zigzagerrorhack)\n", - "frame = PandasTools.LoadSDF(library_directory + library_name + \".SDF\",smilesName='SMILES',molColName='Molecule', includeFingerprints=True)\n", + "frame = PandasTools.LoadSDF(\n", + " library_directory + library_name + \".SDF\",\n", + " smilesName=\"SMILES\",\n", + " molColName=\"Molecule\",\n", + " includeFingerprints=True,\n", + ")\n", "\n", - "frame.info " + "frame.info" ] }, { @@ -5272,12 +5682,12 @@ "# structure library\n", "structure_supplier = Chem.SDMolSupplier(library_directory + library_name + \".SDF\")\n", "print(structure_supplier)\n", - "nist_mols = [structure_supplier[i] for i in range(0,100)] #TODO use all\n", + "nist_mols = 
[structure_supplier[i] for i in range(0, 100)] # TODO use all\n", "df_nist = pd.DataFrame()\n", - "df_nist['mol'] = nist_mols\n", + "df_nist[\"mol\"] = nist_mols\n", "df_nist.dropna(inplace=True)\n", "\n", - "df_nist['smiles'] = df_nist['mol'].apply(lambda x: Chem.MolToSmiles(x))\n", + "df_nist[\"smiles\"] = df_nist[\"mol\"].apply(lambda x: Chem.MolToSmiles(x))\n", "\n", "df_nist.head()" ] @@ -5292,9 +5702,7 @@ }, "outputs": [], "source": [ - "\n", - "\n", - "'''\n", + "\"\"\"\n", "\n", "hmdb_supplier = Chem.SDMolSupplier(f'{home}/data/metabolites/HMDB/structures.sdf')\n", "hmdb_mols = [hmdb_supplier[i] for i in range(0,100)] #TODO use all\n", @@ -5305,7 +5713,7 @@ "df['smiles'] = df['mol'].apply(lambda x: Chem.MolToSmiles(x))\n", "\n", "df.head()\n", - "'''" + "\"\"\"" ] }, { @@ -5331,44 +5739,56 @@ "source": [ "import networkx as nx\n", "\n", - "color_map = {'C': 'gray',\n", - " 'O': 'red',\n", - " 'N': 'blue'}\n", + "color_map = {\"C\": \"gray\", \"O\": \"red\", \"N\": \"blue\"}\n", + "\n", "\n", "def mol_to_nx(mol):\n", " G = nx.Graph()\n", "\n", " for atom in mol.GetAtoms():\n", - " color = color_map[atom.GetSymbol()] if atom.GetSymbol() in color_map.keys() else 'black'\n", - " G.add_node(atom.GetIdx(),\n", - " atomic_num=atom.GetAtomicNum(),\n", - " is_aromatic=atom.GetIsAromatic(),\n", - " atom_symbol=atom.GetSymbol(),\n", - " color=color,\n", - " atom=atom)\n", + " color = (\n", + " color_map[atom.GetSymbol()]\n", + " if atom.GetSymbol() in color_map.keys()\n", + " else \"black\"\n", + " )\n", + " G.add_node(\n", + " atom.GetIdx(),\n", + " atomic_num=atom.GetAtomicNum(),\n", + " is_aromatic=atom.GetIsAromatic(),\n", + " atom_symbol=atom.GetSymbol(),\n", + " color=color,\n", + " atom=atom,\n", + " )\n", "\n", " for bond in mol.GetBonds():\n", - " G.add_edge(bond.GetBeginAtomIdx(),\n", - " bond.GetEndAtomIdx(),\n", - " bond_type=bond.GetBondType(),\n", - " bond=bond)\n", - " #is_double=bond.GEt)\n", + " G.add_edge(\n", + " bond.GetBeginAtomIdx(),\n", + " bond.GetEndAtomIdx(),\n", + " bond_type=bond.GetBondType(),\n", + " bond=bond,\n", + " )\n", + " # is_double=bond.GEt)\n", "\n", " return G\n", "\n", + "\n", "def draw_graph(G, edge_labels=False):\n", " pos = nx.spring_layout(G)\n", - " nx.draw(G,pos=pos,\n", - " labels=nx.get_node_attributes(G, 'atom_symbol'),\n", - " with_labels = True,\n", - " node_color=list(nx.get_node_attributes(G, 'color').values()),\n", - " node_size=800)\n", + " nx.draw(\n", + " G,\n", + " pos=pos,\n", + " labels=nx.get_node_attributes(G, \"atom_symbol\"),\n", + " with_labels=True,\n", + " node_color=list(nx.get_node_attributes(G, \"color\").values()),\n", + " node_size=800,\n", + " )\n", " if edge_labels:\n", " nx.draw_networkx_edge_labels(\n", - " G, pos,\n", - " edge_labels=dict([((n1, n2), f'({n1}, {n2})')\n", - " for n1, n2 in G.edges]),\n", - " font_color='red')\n", + " G,\n", + " pos,\n", + " edge_labels=dict([((n1, n2), f\"({n1}, {n2})\") for n1, n2 in G.edges]),\n", + " font_color=\"red\",\n", + " )\n", " plt.show()\n", "\n", "\n", @@ -5376,14 +5796,35 @@ " def __init__(self):\n", " self.encoded_dim = 0\n", " self.sets = {\n", - " \"symbol\": {\"B\", \"Br\", \"C\", \"Ca\", \"Cl\", \"F\", \"H\", \"I\", \"N\", \"Na\", \"O\", \"P\", \"S\"},\n", - " \"num_hydrogen\": {0, 1, 2, 3, 4, 5, 6, 7, 8}\n", + " \"symbol\": {\n", + " \"B\",\n", + " \"Br\",\n", + " \"C\",\n", + " \"Ca\",\n", + " \"Cl\",\n", + " \"F\",\n", + " \"H\",\n", + " \"I\",\n", + " \"N\",\n", + " \"Na\",\n", + " \"O\",\n", + " \"P\",\n", + " \"S\",\n", + " },\n", + " \"num_hydrogen\": {0, 
1, 2, 3, 4, 5, 6, 7, 8},\n", " }\n", - " self.reduced_features = [\"symbol\"] # Where unknown variables (not in the set) might occur, these will get a combined bit in the encoded vector\n", + " self.reduced_features = [\n", + " \"symbol\"\n", + " ] # Where unknown variables (not in the set) might occur, these will get a combined bit in the encoded vector\n", " self.one_hot_mapper = {}\n", " for feature in self.sets.keys():\n", " variables = self.sets[feature]\n", - " self.one_hot_mapper[feature] = dict(zip(variables, range(self.encoded_dim, len(variables) + self.encoded_dim)))\n", + " self.one_hot_mapper[feature] = dict(\n", + " zip(\n", + " variables,\n", + " range(self.encoded_dim, len(variables) + self.encoded_dim),\n", + " )\n", + " )\n", " self.encoded_dim += len(variables)\n", " if feature in self.reduced_features:\n", " self.encoded_dim += 1\n", @@ -5392,17 +5833,22 @@ " feature_matrix = np.zeros(shape=(G.number_of_nodes(), self.encoded_dim))\n", "\n", " for i in range(G.number_of_nodes()):\n", - " atom = G.nodes()[i]['atom']\n", + " atom = G.nodes()[i][\"atom\"]\n", "\n", - " if not atom.GetSymbol() in self.sets['symbol']:\n", - " feature_matrix[i][self.one_hot_mapper['symbol'][list(self.sets['symbol'])[-1]] + 1] = 1.0\n", + " if not atom.GetSymbol() in self.sets[\"symbol\"]:\n", + " feature_matrix[i][\n", + " self.one_hot_mapper[\"symbol\"][list(self.sets[\"symbol\"])[-1]] + 1\n", + " ] = 1.0\n", " else:\n", - " feature_matrix[i][self.one_hot_mapper['symbol'][atom.GetSymbol()]] = 1.0\n", + " feature_matrix[i][self.one_hot_mapper[\"symbol\"][atom.GetSymbol()]] = 1.0\n", "\n", - " feature_matrix[i][self.one_hot_mapper['num_hydrogen'][atom.GetTotalNumHs()]] = 1.0\n", + " feature_matrix[i][\n", + " self.one_hot_mapper[\"num_hydrogen\"][atom.GetTotalNumHs()]\n", + " ] = 1.0\n", "\n", " return feature_matrix\n", "\n", + "\n", "node_encoder = FeatureEncoder()" ] }, @@ -5428,25 +5874,45 @@ "outputs": [], "source": [ "num_elems = 12\n", + "\n", + "\n", "def add_dataframe_features(df):\n", - " df['graph'] = df['mol'].apply(mol_to_nx)\n", - " df['features'] = df['graph'].apply(lambda x: node_encoder.encode(x))\n", - " df['Xsymbol'] = df['graph'].apply(lambda x: [x.nodes[atom]['atom_symbol'] for atom in x.nodes()])\n", - " df['Xi'] = df['graph'].apply(lambda x: [min(x.nodes[atom]['atomic_num'], num_elems - 1) for atom in x.nodes()])\n", - " df['X'] = df['Xi'].apply(lambda x: to_categorical(x, num_classes=num_elems))\n", - " df['A'] = df['graph'].apply(nx.convert_matrix.to_numpy_matrix)\n", - " df['Atilde'] = df['A'].apply(lambda x: x + np.eye(N=x.shape[0]))\n", - " df['Id'] = df['A'].apply(lambda x: np.eye(N=x.shape[0]))\n", - " df['deg'] = df['A'].apply(lambda x: tf.transpose([tf.clip_by_value(tf.reduce_sum(x, axis=-1), 0.0001, 1000.0)]))\n", - " df['isAromatic'] = df['graph'].apply(lambda x: np.array([[x.nodes[atom]['is_aromatic'] for atom in x.nodes()]]).T)\n", + " df[\"graph\"] = df[\"mol\"].apply(mol_to_nx)\n", + " df[\"features\"] = df[\"graph\"].apply(lambda x: node_encoder.encode(x))\n", + " df[\"Xsymbol\"] = df[\"graph\"].apply(\n", + " lambda x: [x.nodes[atom][\"atom_symbol\"] for atom in x.nodes()]\n", + " )\n", + " df[\"Xi\"] = df[\"graph\"].apply(\n", + " lambda x: [\n", + " min(x.nodes[atom][\"atomic_num\"], num_elems - 1) for atom in x.nodes()\n", + " ]\n", + " )\n", + " df[\"X\"] = df[\"Xi\"].apply(lambda x: to_categorical(x, num_classes=num_elems))\n", + " df[\"A\"] = df[\"graph\"].apply(nx.convert_matrix.to_numpy_matrix)\n", + " df[\"Atilde\"] = df[\"A\"].apply(lambda x: x + 
np.eye(N=x.shape[0]))\n", + " df[\"Id\"] = df[\"A\"].apply(lambda x: np.eye(N=x.shape[0]))\n", + " df[\"deg\"] = df[\"A\"].apply(\n", + " lambda x: tf.transpose(\n", + " [tf.clip_by_value(tf.reduce_sum(x, axis=-1), 0.0001, 1000.0)]\n", + " )\n", + " )\n", + " df[\"isAromatic\"] = df[\"graph\"].apply(\n", + " lambda x: np.array([[x.nodes[atom][\"is_aromatic\"] for atom in x.nodes()]]).T\n", + " )\n", "\n", " # Extras\n", - " df['isN'] = df['graph'].apply(lambda x: np.array([[int(x.nodes[atom]['atom_symbol'] == 'N') for atom in x.nodes()]]))\n", - " df['isN_in_radius1'] = [df.loc[i, 'Atilde'] * df.loc[i,'isN'].T for i in df.index]\n", - " df['isN_in_radius1'] = df['isN_in_radius1'].apply(lambda x: x.clip(0, 1))\n", - " df['isN_neighboring'] = [df.loc[i, 'A'] * df.loc[i,'isN'].T for i in df.index]\n", - " df['isN_neighboring'] = df['isN_neighboring'].apply(lambda x: x.clip(0, 1))\n", + " df[\"isN\"] = df[\"graph\"].apply(\n", + " lambda x: np.array(\n", + " [[int(x.nodes[atom][\"atom_symbol\"] == \"N\") for atom in x.nodes()]]\n", + " )\n", + " )\n", + " df[\"isN_in_radius1\"] = [df.loc[i, \"Atilde\"] * df.loc[i, \"isN\"].T for i in df.index]\n", + " df[\"isN_in_radius1\"] = df[\"isN_in_radius1\"].apply(lambda x: x.clip(0, 1))\n", + " df[\"isN_neighboring\"] = [df.loc[i, \"A\"] * df.loc[i, \"isN\"].T for i in df.index]\n", + " df[\"isN_neighboring\"] = df[\"isN_neighboring\"].apply(lambda x: x.clip(0, 1))\n", " return df\n", + "\n", + "\n", "df_nist = add_dataframe_features(df_nist)" ] }, @@ -5486,6 +5952,7 @@ "\n", " return activation(AHW)\n", "\n", + "\n", "def gnn(features, A, self_transform, transform, activation):\n", " HW = transform(features)\n", " AHW = tf.matmul(A, HW)\n", @@ -5515,17 +5982,36 @@ }, "outputs": [], "source": [ - "layer_conv1, layer_self1 = tf.keras.layers.Dense(num_elems), tf.keras.layers.Dense(num_elems)\n", - "layer_conv2, layer_self2 = tf.keras.layers.Dense(num_elems), tf.keras.layers.Dense(num_elems)\n", - "layer_conv3, layer_self3 = tf.keras.layers.Dense(num_elems), tf.keras.layers.Dense(num_elems)\n", - "layer_conv4, layer_self4 = tf.keras.layers.Dense(num_elems), tf.keras.layers.Dense(num_elems)\n", - "layer_conv5, layer_self5 = tf.keras.layers.Dense(num_elems), tf.keras.layers.Dense(num_elems)\n", + "layer_conv1, layer_self1 = (\n", + " tf.keras.layers.Dense(num_elems),\n", + " tf.keras.layers.Dense(num_elems),\n", + ")\n", + "layer_conv2, layer_self2 = (\n", + " tf.keras.layers.Dense(num_elems),\n", + " tf.keras.layers.Dense(num_elems),\n", + ")\n", + "layer_conv3, layer_self3 = (\n", + " tf.keras.layers.Dense(num_elems),\n", + " tf.keras.layers.Dense(num_elems),\n", + ")\n", + "layer_conv4, layer_self4 = (\n", + " tf.keras.layers.Dense(num_elems),\n", + " tf.keras.layers.Dense(num_elems),\n", + ")\n", + "layer_conv5, layer_self5 = (\n", + " tf.keras.layers.Dense(num_elems),\n", + " tf.keras.layers.Dense(num_elems),\n", + ")\n", "layer_final = tf.keras.layers.Dense(1)\n", "\n", + "\n", "def gnn_pooling_model(features, A):\n", - " hidden_features = gnn_pool(features, A, layer_conv1, activation=tf.nn.relu)\n", - " output_logits = layer_final(hidden_features) #= gnn(hidden_features, A, layer_2, activation=tf.identity)#tf.nn.sigmoid)#tf.identity)\n", - " return output_logits\n", + " hidden_features = gnn_pool(features, A, layer_conv1, activation=tf.nn.relu)\n", + " output_logits = layer_final(\n", + " hidden_features\n", + " ) # = gnn(hidden_features, A, layer_2, activation=tf.identity)#tf.nn.sigmoid)#tf.identity)\n", + " return output_logits\n", + "\n", "\n", 
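+    "\n",
+    "\n",
+    "# Toy illustration (a sketch with made-up matrices, not used below): the gnn()\n",
+    "# helper above implements H' = act(A @ H @ W + H @ V), i.e. every node pools\n",
+    "# transformed neighbour features and adds a transform of itself. For two nodes\n",
+    "# joined by one edge, one-hot features H, W = 0.5 everywhere and V = I:\n",
+    "toy_A = np.array([[0.0, 1.0], [1.0, 0.0]])\n",
+    "toy_H = np.eye(2)\n",
+    "# relu(A @ H @ W + H @ V) -> [[1.5, 0.5], [0.5, 1.5]]: each node keeps its\n",
+    "# own feature and gains 0.5 from its neighbour in every slot\n",
+    "print(np.maximum(toy_A @ toy_H @ np.full((2, 2), 0.5) + toy_H @ np.eye(2), 0.0))\n",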
"def gnn_model(features, A):\n", " hidden_features = gnn(features, A, layer_self1, layer_conv1, tf.nn.relu)\n", @@ -5550,6 +6036,7 @@ "def accuracy_logits(y, yhat):\n", " return sklearn.metrics.accuracy_score(y, [x > 0 for x in yhat])\n", "\n", + "\n", "def validate_model(data, model, y_label, verbose=False, **kwargs):\n", " losses, y_true, y_hat = [], [], []\n", " correct_mol = 0\n", @@ -5558,22 +6045,29 @@ " model_kwargs = {}\n", " for key, value in kwargs.items():\n", " model_kwargs[key] = tf.cast(d[value], dtype=tf.float32)\n", - " logits = model(d.features, tf.cast(d.A / d.deg, dtype=tf.float32), **model_kwargs)\n", - " #logits = model(d.features, tf.cast(d.Id, dtype=tf.float32), **model_kwargs)\n", + " logits = model(\n", + " d.features, tf.cast(d.A / d.deg, dtype=tf.float32), **model_kwargs\n", + " )\n", + " # logits = model(d.features, tf.cast(d.Id, dtype=tf.float32), **model_kwargs)\n", " y_tensor = tf.cast(d[y_label], dtype=tf.float32)\n", " loss = tf.nn.sigmoid_cross_entropy_with_logits(y_tensor, logits)\n", "\n", " losses = np.append(losses, loss.numpy())\n", " y_hat += [float(x) > 0 for x in logits]\n", " y_true += [int(x) for x in y_tensor]\n", - " if y_hat[-len(logits):] == y_true[-len(logits):]:\n", - " correct_mol+=1\n", + " if y_hat[-len(logits) :] == y_true[-len(logits) :]:\n", + " correct_mol += 1\n", "\n", " acc, mean_loss = sklearn.metrics.accuracy_score(y_true, y_hat), np.mean(losses)\n", - " if verbose: print(\"Node/Edge level accuracy: %.03f; Mean loss: %.03f; Correct molecules: %.03f\" % (acc, mean_loss, correct_mol / data.shape[0]))\n", + " if verbose:\n", + " print(\n", + " \"Node/Edge level accuracy: %.03f; Mean loss: %.03f; Correct molecules: %.03f\"\n", + " % (acc, mean_loss, correct_mol / data.shape[0])\n", + " )\n", "\n", " return acc, mean_loss, correct_mol / data.shape[0]\n", "\n", + "\n", "def train_gnn(data):\n", "\n", " optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)\n", @@ -5582,9 +6076,11 @@ " for i in data_train.index:\n", " with tf.GradientTape() as t:\n", " d = data_train.loc[i]\n", - " #logits = gnn_pooling_model(d.features, tf.cast(d.Atilde, dtype=tf.float32))\n", + " # logits = gnn_pooling_model(d.features, tf.cast(d.Atilde, dtype=tf.float32))\n", " logits = gnn_model(d.features, tf.cast(d.A / d.deg, dtype=tf.float32))\n", - " loss = tf.nn.sigmoid_cross_entropy_with_logits(tf.cast(d.isAromatic, dtype=tf.float32), logits)\n", + " loss = tf.nn.sigmoid_cross_entropy_with_logits(\n", + " tf.cast(d.isAromatic, dtype=tf.float32), logits\n", + " )\n", "\n", " variables = t.watched_variables()\n", " gradients = t.gradient(loss, variables)\n", @@ -5592,7 +6088,9 @@ "\n", " # Validate loss/acc\n", " print(\"Epoch %s\" % epoch)\n", - " validate_model(data=data_val, model=gnn_model, y_label=\"isAromatic\", verbose=True)\n", + " validate_model(\n", + " data=data_val, model=gnn_model, y_label=\"isAromatic\", verbose=True\n", + " )\n", "\n", " return" ] @@ -5623,7 +6121,7 @@ "source": [ "print(\"Test GNN\")\n", "acc, loss, correct = validate_model(df_test, gnn_model, \"isAromatic\", verbose=True)\n", - "#print()" + "# print()" ] }, { @@ -5661,13 +6159,11 @@ }, "outputs": [], "source": [ - "\n", "yhat = gnn_model(d.features, tf.cast(d.A / d.deg, dtype=tf.float32))\n", "ytrue = tf.cast(d.isAromatic, dtype=tf.float32)\n", "\n", "print(\" y y_hat \")\n", - "print(np.round(np.array(tf.concat([ytrue, tf.nn.sigmoid(yhat)], axis=1)), decimals=2))\n", - "\n" + "print(np.round(np.array(tf.concat([ytrue, tf.nn.sigmoid(yhat)], axis=1)), decimals=2))" ] }, { @@ 
-5703,26 +6199,32 @@ "\n", " row = 0\n", " for j in range(deg.shape[0]):\n", - " row_degree = int(deg[j,0].numpy())\n", + " row_degree = int(deg[j, 0].numpy())\n", " for i in range(row_degree):\n", " edges_to = np.where(A[j] > 0.001)[1]\n", " AL[row + i, j] = 1.0\n", " AR[row + i, edges_to[i]] = 1.0\n", - " y_edge.append(G[j][edges_to[i]]['bond_type'].name == \"AROMATIC\") # Add y condition here\n", + " y_edge.append(\n", + " G[j][edges_to[i]][\"bond_type\"].name == \"AROMATIC\"\n", + " ) # Add y condition here\n", " edge_idx.append((j, edges_to[i]))\n", " row += row_degree\n", "\n", " return AL, AR, y_edge, edge_idx\n", "\n", + "\n", "def add_dataframe_edge_features(df):\n", - " df['AL'] = df.apply(lambda x: compute_helper_matrices(x['A'], x['deg'], x['graph']), axis=1)\n", - " df['AR'] = df['AL'].apply(lambda x: x[1])\n", - " df['edges_is_aromatic'] = df['AL'].apply(lambda x: np.array([x[2]]).T)\n", - " df['edges_idx'] = df['AL'].apply(lambda x: x[3])\n", - " df['AL'] = df['AL'].apply(lambda x: x[0])\n", + " df[\"AL\"] = df.apply(\n", + " lambda x: compute_helper_matrices(x[\"A\"], x[\"deg\"], x[\"graph\"]), axis=1\n", + " )\n", + " df[\"AR\"] = df[\"AL\"].apply(lambda x: x[1])\n", + " df[\"edges_is_aromatic\"] = df[\"AL\"].apply(lambda x: np.array([x[2]]).T)\n", + " df[\"edges_idx\"] = df[\"AL\"].apply(lambda x: x[3])\n", + " df[\"AL\"] = df[\"AL\"].apply(lambda x: x[0])\n", " return df\n", "\n", - "df_nist = add_dataframe_edge_features(df_nist)\n" + "\n", + "df_nist = add_dataframe_edge_features(df_nist)" ] }, { @@ -5737,18 +6239,22 @@ "source": [ "edge_prediction_layer = tf.keras.layers.Dense(1)\n", "\n", + "\n", "def edge_pred(features, AL, AR, transform, activation):\n", " X = tf.concat([tf.matmul(AL, features), tf.matmul(AR, features)], axis=1)\n", "\n", " return activation(transform(tf.cast(X, dtype=tf.float32)))\n", "\n", + "\n", "def edge_prediction_model(features, A, AL, AR):\n", " hidden_features = gnn(features, A, layer_self1, layer_conv1, tf.nn.relu)\n", " hidden_features = gnn(hidden_features, A, layer_self2, layer_conv2, tf.nn.relu)\n", " hidden_features = gnn(hidden_features, A, layer_self3, layer_conv3, tf.nn.relu)\n", " hidden_features = gnn(hidden_features, A, layer_self4, layer_conv4, tf.nn.relu)\n", " hidden_features = gnn(hidden_features, A, layer_self5, layer_conv5, tf.nn.relu)\n", - " output_logits = edge_pred(hidden_features, AL, AR, edge_prediction_layer, tf.identity)\n", + " output_logits = edge_pred(\n", + " hidden_features, AL, AR, edge_prediction_layer, tf.identity\n", + " )\n", " return output_logits\n", "\n", "\n", @@ -5758,14 +6264,19 @@ " data_train, data_val = train_test_split(data, test_size=0.1)\n", "\n", " for epoch in range(1, 6):\n", - " L=[]\n", + " L = []\n", " for i in data_train.index:\n", " with tf.GradientTape() as t:\n", " d = data.loc[i]\n", " if d.graph.number_of_edges() == 0:\n", " continue\n", - " logits = edge_prediction_model(d.features, tf.cast(d.A / d.deg, dtype=tf.float32), tf.cast(d.AL, dtype=tf.float32), tf.cast(d.AR, dtype=tf.float32))\n", - " #logits = edge_prediction_model(d.features, tf.cast(d.Id, dtype=tf.float32), tf.cast(d.AL, dtype=tf.float32), tf.cast(d.AR, dtype=tf.float32))\n", + " logits = edge_prediction_model(\n", + " d.features,\n", + " tf.cast(d.A / d.deg, dtype=tf.float32),\n", + " tf.cast(d.AL, dtype=tf.float32),\n", + " tf.cast(d.AR, dtype=tf.float32),\n", + " )\n", + " # logits = edge_prediction_model(d.features, tf.cast(d.Id, dtype=tf.float32), tf.cast(d.AL, dtype=tf.float32), tf.cast(d.AR, 
dtype=tf.float32))\n",
     "            y_tensor = tf.cast(d.edges_is_aromatic, dtype=tf.float32)\n",
     "            loss = tf.nn.sigmoid_cross_entropy_with_logits(y_tensor, logits)\n",
     "            L.append(np.mean(loss))\n",
@@ -5773,9 +6284,16 @@
     "            gradients = t.gradient(loss, variables)\n",
     "            optimizer.apply_gradients(zip(gradients, variables))\n",
     "\n",
-    "         # Validate loss/acc\n",
+    "        # Validate loss/acc\n",
     "        print(\"Epoch %s\" % epoch)\n",
-    "        validate_model(data=data_val, model=edge_prediction_model, y_label=\"edges_is_aromatic\", verbose=True, AL=\"AL\", AR=\"AR\")\n",
+    "        validate_model(\n",
+    "            data=data_val,\n",
+    "            model=edge_prediction_model,\n",
+    "            y_label=\"edges_is_aromatic\",\n",
+    "            verbose=True,\n",
+    "            AL=\"AL\",\n",
+    "            AR=\"AR\",\n",
+    "        )\n",
     "\n",
     "    return"
    ]
   },
   {
@@ -5790,8 +6308,10 @@
    },
    "outputs": [],
    "source": [
-    "df_train, df_test = train_test_split(df_nist, test_size=0.5) # Redo since df has been change\n",
-    "train_edge_gnn(df_train) #TODO test train"
+    "df_train, df_test = train_test_split(\n",
+    "    df_nist, test_size=0.5\n",
+    ")  # Redo since df has been changed\n",
+    "train_edge_gnn(df_train)  # TODO test train"
    ]
   },
   {
@@ -5805,7 +6325,14 @@
    "outputs": [],
    "source": [
     "print(\"Test Edge GNN\")\n",
-    "acc, loss, correct = validate_model(data=df_test, model=edge_prediction_model, y_label=\"edges_is_aromatic\", verbose=True, AL=\"AL\", AR=\"AR\")"
+    "acc, loss, correct = validate_model(\n",
+    "    data=df_test,\n",
+    "    model=edge_prediction_model,\n",
+    "    y_label=\"edges_is_aromatic\",\n",
+    "    verbose=True,\n",
+    "    AL=\"AL\",\n",
+    "    AR=\"AR\",\n",
+    ")"
    ]
   },
   {
@@ -5819,7 +6346,12 @@
    "outputs": [],
    "source": [
     "d = df_nist.iloc[4]\n",
-    "yhat = edge_prediction_model(d.features, tf.cast(d.A / d.deg, dtype=tf.float32), tf.cast(d.AL, dtype=tf.float32), tf.cast(d.AR, dtype=tf.float32))\n",
+    "yhat = edge_prediction_model(\n",
+    "    d.features,\n",
+    "    tf.cast(d.A / d.deg, dtype=tf.float32),\n",
+    "    tf.cast(d.AL, dtype=tf.float32),\n",
+    "    tf.cast(d.AR, dtype=tf.float32),\n",
+    ")\n",
     "ytrue = tf.cast(d.edges_is_aromatic, dtype=tf.float32)\n",
     "print(np.round(np.array(tf.concat([ytrue, tf.nn.sigmoid(yhat)], axis=1)), decimals=2))\n",
     "d.mol"
    ]
   },
   {
@@ -5839,8 +6371,8 @@
     "\n",
     "for i in range(tf.cast(d.A, dtype=tf.float32).shape[0]):\n",
     "    for j in range(tf.cast(d.A, dtype=tf.float32).shape[1]):\n",
-    "        if d.A[i,j] >= 1:\n",
-    "            print(G[i][j]['bond'].GetBondType())"
+    "        if d.A[i, j] >= 1:\n",
+    "            print(G[i][j][\"bond\"].GetBondType())"
    ]
   },
   {
@@ -5853,25 +6385,26 @@
    },
    "outputs": [],
    "source": [
-    "def break_bond(mol, i,j):\n",
+    "def break_bond(mol, i, j):\n",
     "    em = Chem.EditableMol(mol)\n",
     "    em.RemoveBond(i, j)\n",
     "    new_mol = em.GetMol()\n",
-    "    frags = Chem.GetMolFrags(new_mol,asMols=True)\n",
+    "    frags = Chem.GetMolFrags(new_mol, asMols=True)\n",
     "    return new_mol, frags\n",
     "\n",
+    "\n",
     "dummy = df_nist.iloc[12]\n",
-    "#print(dummy)\n",
-    "#dummy.mol\n",
+    "# print(dummy)\n",
+    "# dummy.mol\n",
     "\n",
-    "#print(dummy.A)\n",
-    "#print(dummy.Xsymbol)\n",
+    "# print(dummy.A)\n",
+    "# print(dummy.Xsymbol)\n",
     "draw_graph(dummy.graph, edge_labels=True)\n",
     "\n",
     "\n",
     "new_mol, fragments = break_bond(dummy.mol, 0, 6)\n",
     "\n",
-    "new_mol\n"
+    "new_mol"
    ]
   },
   {
@@ -5884,14 +6417,12 @@
    },
    "outputs": [],
    "source": [
-    "\n",
-    "\n",
     "nicotine_mol = Chem.MolFromSmiles(\"NCCCCC(N)CC(O)=O\")\n",
     "print(Descriptors.ExactMolWt(nicotine_mol) + 1)\n",
     "G = mol_to_nx(nicotine_mol)\n",
     "draw_graph(G, edge_labels=True)\n",
     "new_mol, fragments = break_bond(nicotine_mol, 3, 4)\n",
-    "print([Descriptors.ExactMolWt(f) +1 for f in 
fragments])\n",
+    "print([Descriptors.ExactMolWt(f) + 1 for f in fragments])\n",
     "new_mol"
    ]
   },
   {
@@ -5909,20 +6440,26 @@
     "MZ_TOLERANCE = 0.1\n",
     "PROTON_MZ = 1.007\n",
     "d = df_nist.iloc[1]\n",
+    "\n",
+    "\n",
     "def getIonWeights(mol, y_idx, charge):\n",
     "    weight = []\n",
-    "    for i,j in y_idx:\n",
-    "\n",
-    "        #[] TODO which weight is which fragment?\n",
+    "    for i, j in y_idx:\n",
+    "        # [] TODO which weight is which fragment?\n",
     "        try:\n",
     "            new_mol, fragments = break_bond(mol, int(i), int(j))\n",
     "        except (Chem.AtomKekulizeException, Chem.KekulizeException):\n",
-    "            #print(i,j, \"Error\", Chem.AtomKekulizeException)\n",
+    "            # print(i,j, \"Error\", Chem.AtomKekulizeException)\n",
     "            weight.append([np.nan, np.nan])\n",
     "        else:\n",
-    "            #print(i,j,[Descriptors.ExactMolWt(f) + charge for f in fragments])\n",
+    "            # print(i,j,[Descriptors.ExactMolWt(f) + charge for f in fragments])\n",
     "            if len(fragments) > 1:\n",
-    "                weight.append([Descriptors.ExactMolWt(f) + (PROTON_MZ * charge) for f in fragments])\n",
+    "                weight.append(\n",
+    "                    [\n",
+    "                        Descriptors.ExactMolWt(f) + (PROTON_MZ * charge)\n",
+    "                        for f in fragments\n",
+    "                    ]\n",
+    "                )\n",
     "            else:\n",
     "                weight.append([np.nan, np.nan])\n",
     "    return weight\n",
@@ -5931,7 +6468,7 @@
     "print(d.edges_idx)\n",
     "w = getIonWeights(d.mol, d.edges_idx, 1)\n",
     "print(w)\n",
-    "np.unique(w)\n"
+    "np.unique(w)"
    ]
   },
   {
@@ -5944,7 +6481,7 @@
    },
    "outputs": [],
    "source": [
-    "draw_graph(d.graph,edge_labels=True)"
+    "draw_graph(d.graph, edge_labels=True)"
    ]
   },
   {
@@ -5972,9 +6509,9 @@
     "# spectral library\n",
     "nist_msp = mspReader.read(f\"{home}/data/metabolites/MassBank/MassBank_NIST.msp\")\n",
     "df = pd.DataFrame(nist_msp)\n",
-    "df['mol'] = df['SMILES'].apply(Chem.MolFromSmiles)\n",
+    "df[\"mol\"] = df[\"SMILES\"].apply(Chem.MolFromSmiles)\n",
     "df.dropna(inplace=True)\n",
-    "print(df.shape)\n"
+    "print(df.shape)"
    ]
   },
   {
@@ -5987,13 +6524,11 @@
    },
    "outputs": [],
    "source": [
-    "\n",
-    "\n",
     "d = df.loc[181]\n",
     "\n",
     "print(d)\n",
     "d.mol\n",
-    "#df_nist"
+    "# df_nist"
    ]
   },
   {
@@ -6016,7 +6551,7 @@
     "print(df.Precursor_type.unique())\n",
     "print(df.Instrument_type.unique())\n",
     "print(df.Collision_energy.unique())\n",
-    "print(sum(df.Collision_energy==\"30(NCE)\"))\n",
+    "print(sum(df.Collision_energy == \"30(NCE)\"))\n",
     "sns.histplot(df.Collision_energy)\n",
     "plt.show()"
    ]
   },
   {
@@ -6042,45 +6577,75 @@
    },
    "outputs": [],
    "source": [
-    "df = df[df.Ion_mode == 'POSITIVE']\n",
-    "df = df[df.Precursor_type == '[M+H]+']\n",
-    "df = df[df.Spectrum_type == 'MS2']\n",
+    "df = df[df.Ion_mode == \"POSITIVE\"]\n",
+    "df = df[df.Precursor_type == \"[M+H]+\"]\n",
+    "df = df[df.Spectrum_type == \"MS2\"]\n",
     "df = df[df.Instrument_type == \"LC-ESI-ITFT\"]\n",
-    "df['Num Peaks'] = df['Num Peaks'].astype(int)\n",
-    "df = df[df['Num Peaks'] > 1]\n",
+    "df[\"Num Peaks\"] = df[\"Num Peaks\"].astype(int)\n",
+    "df = df[df[\"Num Peaks\"] > 1]\n",
     "print(df.shape)\n",
-    "#df_nist = df_nist.iloc[:1000] # Reduce to 1000\n",
+    "# df_nist = df_nist.iloc[:1000]  # Reduce to 1000\n",
     "\n",
     "\n",
     "df = add_dataframe_features(df)\n",
     "df = add_dataframe_edge_features(df)\n",
     "\n",
     "\n",
-    "def NCE_to_eV(NCE, isolation_center, charge): # isolation_ceter = precursor mz (???)\n",
-    "    return (NCE * isolation_center) / (500 * charge) # NCE seems to be calibrated to 500 m/z ions (???)\n",
+    "def NCE_to_eV(NCE, isolation_center, charge):  # isolation_center = precursor mz (???)\n",
+    "    return (NCE * isolation_center) / (\n",
+    "        500 * charge\n",
+    "    )  # NCE seems to be calibrated to 500 m/z ions (???)\n",
+    "\n",
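+    "\n",
+    "# Worked example for NCE_to_eV (illustrative values only; the 500 m/z reference\n",
+    "# point follows the comment above and is an assumption about the vendor's NCE\n",
+    "# calibration): at NCE 30 and charge 1, a precursor at m/z 500 maps to\n",
+    "# (30 * 500) / (500 * 1) = 30 eV, while m/z 250 gives only 15 eV.\n",
+    "for example_mz in (250.0, 500.0, 1000.0):\n",
+    "    print(example_mz, NCE_to_eV(30.0, example_mz, charge=1))  # 15.0, 30.0, 60.0\n",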
"def eV_to_NCE(eV):\n", " return None\n", "\n", + "\n", "NCEs = []\n", "\n", - "print(any([abs((float(df.loc[x, 'ExactMass']) - Descriptors.ExactMolWt(df.loc[x, 'mol'])) > 0.2) for x in df.index]))\n", - "print(sum([abs((float(df.loc[x, 'PrecursorMZ']) - (Descriptors.ExactMolWt(df.loc[x, 'mol']) + PROTON_MZ)) > 0.2) for x in df.index]))\n", + "print(\n", + " any(\n", + " [\n", + " abs(\n", + " (\n", + " float(df.loc[x, \"ExactMass\"])\n", + " - Descriptors.ExactMolWt(df.loc[x, \"mol\"])\n", + " )\n", + " > 0.2\n", + " )\n", + " for x in df.index\n", + " ]\n", + " )\n", + ")\n", + "print(\n", + " sum(\n", + " [\n", + " abs(\n", + " (\n", + " float(df.loc[x, \"PrecursorMZ\"])\n", + " - (Descriptors.ExactMolWt(df.loc[x, \"mol\"]) + PROTON_MZ)\n", + " )\n", + " > 0.2\n", + " )\n", + " for x in df.index\n", + " ]\n", + " )\n", + ")\n", "\n", "#\n", "# Adjust types and assess specific features\n", "#\n", "\n", - "df['PrecursorMZ'] = df['PrecursorMZ'].astype('float32')\n", - "df['theoretical_PrecursorMZ'] = df['mol'].apply(lambda x: Descriptors.ExactMolWt(x) + PROTON_MZ)\n", - "df['PrecursorMZ_difference'] = df['PrecursorMZ'] - df['theoretical_PrecursorMZ']\n", - "df['absPrecursorMZ_difference'] = df['PrecursorMZ_difference'].apply(abs)\n", + "df[\"PrecursorMZ\"] = df[\"PrecursorMZ\"].astype(\"float32\")\n", + "df[\"theoretical_PrecursorMZ\"] = df[\"mol\"].apply(\n", + " lambda x: Descriptors.ExactMolWt(x) + PROTON_MZ\n", + ")\n", + "df[\"PrecursorMZ_difference\"] = df[\"PrecursorMZ\"] - df[\"theoretical_PrecursorMZ\"]\n", + "df[\"absPrecursorMZ_difference\"] = df[\"PrecursorMZ_difference\"].apply(abs)\n", "\n", "df = df[df.absPrecursorMZ_difference < MZ_TOLERANCE]\n", "\n", - "print(df.shape)\n", - "\n", - "\n" + "print(df.shape)" ] }, { @@ -6095,9 +6660,8 @@ "source": [ "d = df.loc[181]\n", "\n", - "#d['graph']\n", - "draw_graph(d.graph)\n", - "\n" + "# d['graph']\n", + "draw_graph(d.graph)" ] }, { @@ -6113,30 +6677,34 @@ "from treelib import Node, Tree\n", "from copy import copy\n", "\n", + "\n", "def create_fragments(mol, i, j):\n", " try:\n", " new_mol, fragments = break_bond(mol, int(i), int(j))\n", " except (Chem.AtomKekulizeException, Chem.KekulizeException):\n", - " #print(i,j, \"Error\", Chem.AtomKekulizeException)\n", + " # print(i,j, \"Error\", Chem.AtomKekulizeException)\n", " new_mol = None\n", " fragments = [None, None]\n", " else:\n", " if len(fragments) < 1:\n", - " #TODO resolve ring break\n", + " # TODO resolve ring break\n", " fragments = [fragments[0], None]\n", " return new_mol, fragments\n", "\n", + "\n", "#\n", "# TODO DOOOOOOOOOO\n", "#\n", "\n", + "\n", "def morganFinger(x):\n", " return AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024)\n", "\n", + "\n", "def equalMols(mol, other):\n", "\n", " funcs = [Chem.Descriptors.ExactMolWt, morganFinger, AllChem.GetMACCSKeysFingerprint]\n", - " #func = Chem.Descriptors.ExactMolWt # TODO add more here !!!!! When are mols equal????\n", + " # func = Chem.Descriptors.ExactMolWt # TODO add more here !!!!! 
When are mols equal????\n", " for func in funcs:\n", " if func(mol) == func(other):\n", " continue\n", @@ -6144,36 +6712,51 @@ " return False\n", " return True\n", "\n", + "\n", "def is_fragment_in_list(fragment, fragment_list):\n", " for f in fragment_list:\n", " if equalMols(fragment, f):\n", " return True\n", " return False\n", "\n", + "\n", "def build_fragmentation_tree(mol, edges, depth=2, parent_tree=None, parent_id=None):\n", " fragmentation_tree = Tree(tree=parent_tree)\n", " ID = fragmentation_tree.size()\n", - " fragmentation_tree.create_node(tag=Chem.Descriptors.ExactMolWt(mol), parent=parent_id, identifier=ID, data=mol)\n", - "\n", + " fragmentation_tree.create_node(\n", + " tag=Chem.Descriptors.ExactMolWt(mol), parent=parent_id, identifier=ID, data=mol\n", + " )\n", "\n", " listed_fragments = []\n", "\n", - "\n", - " for i,j in edges:\n", + " for i, j in edges:\n", " _, fragments = create_fragments(mol, i, j)\n", " for f in fragments:\n", " if f is not None:\n", " if is_fragment_in_list(f, listed_fragments):\n", " continue\n", - " if depth == 1: # anchor\n", - " fragmentation_tree.create_node(tag=Chem.Descriptors.ExactMolWt(f), identifier=fragmentation_tree.size(), parent=ID, data=f)\n", - " else: # recursion TODO OPTIMIZE\n", - " # build graph, adjacency matrix and index edges\n", + " if depth == 1: # anchor\n", + " fragmentation_tree.create_node(\n", + " tag=Chem.Descriptors.ExactMolWt(f),\n", + " identifier=fragmentation_tree.size(),\n", + " parent=ID,\n", + " data=f,\n", + " )\n", + " else: # recursion TODO OPTIMIZE\n", + " # build graph, adjacency matrix and index edges\n", " G = mol_to_nx(f)\n", " A = nx.convert_matrix.to_numpy_matrix(G)\n", - " deg = tf.transpose([tf.clip_by_value(tf.reduce_sum(A, axis=-1), 0.0001, 1000.0)])\n", - " _,_,_, edges = compute_helper_matrices(A, deg, G)\n", - " fragmentation_tree = build_fragmentation_tree(f, edges, depth=depth-1, parent_tree=fragmentation_tree, parent_id=ID)\n", + " deg = tf.transpose(\n", + " [tf.clip_by_value(tf.reduce_sum(A, axis=-1), 0.0001, 1000.0)]\n", + " )\n", + " _, _, _, edges = compute_helper_matrices(A, deg, G)\n", + " fragmentation_tree = build_fragmentation_tree(\n", + " f,\n", + " edges,\n", + " depth=depth - 1,\n", + " parent_tree=fragmentation_tree,\n", + " parent_id=ID,\n", + " )\n", "\n", " listed_fragments.append(f)\n", " return fragmentation_tree\n", @@ -6185,7 +6768,12 @@ "t.show(idhidden=False)\n", "print(t.size(level=0), t.size(level=1), t.size(level=2))\n", "\n", - "Chem.Draw.MolsToGridImage([d.mol, t.get_node(20).data, t.get_node(14).data, t.get_node(263).data], molsPerRow=4, useSVG=True, legends=[\"intact\", \"fragment 1\", \"fragment 2\", \"fragment 3\"])" + "Chem.Draw.MolsToGridImage(\n", + " [d.mol, t.get_node(20).data, t.get_node(14).data, t.get_node(263).data],\n", + " molsPerRow=4,\n", + " useSVG=True,\n", + " legends=[\"intact\", \"fragment 1\", \"fragment 2\", \"fragment 3\"],\n", + ")" ] }, { @@ -6198,22 +6786,26 @@ }, "outputs": [], "source": [ - "\n", "def do_peaks_match(peak1, peak2, mz_tolerance=MZ_TOLERANCE):\n", " return abs(peak1 - peak2) < mz_tolerance\n", - "#df_nist['num_peaks_matched'] = df_nist.apply(lambda x: sum([any([do_peaks_match(frag_peak, peak) for peak in x['peaks']['mz']]) for frag_peak in x['unique_fragment_mz']]), axis=1)\n", + "\n", + "\n", + "# df_nist['num_peaks_matched'] = df_nist.apply(lambda x: sum([any([do_peaks_match(frag_peak, peak) for peak in x['peaks']['mz']]) for frag_peak in x['unique_fragment_mz']]), axis=1)\n", + "\n", "\n", "def 
does_peak_match_any(peak, peak_list, mz_tolerance=MZ_TOLERANCE):\n", " return any([do_peaks_match(peak, peak2, mz_tolerance) for peak2 in peak_list])\n", "\n", + "\n", "def match_peak_to_list(peak, peak_list, mz_tolerance=MZ_TOLERANCE):\n", "\n", " for i in range(len(peak_list)):\n", " if do_peaks_match(peak, peak_list[i], mz_tolerance):\n", - " return i #TODO only reports first match ::TOODODODODO\n", + " return i # TODO only reports first match ::TOODODODODO\n", "\n", " return np.nan\n", "\n", + "\n", "def get_matching_peaks_until_depth(tree, peaks, depth):\n", " fragments = []\n", " peak_list = []\n", @@ -6223,22 +6815,31 @@ " if tree.level(node.identifier) > depth:\n", " continue\n", " mz = node.tag + PROTON_MZ\n", - " peak_idx = match_peak_to_list(mz, peaks['mz'])\n", + " peak_idx = match_peak_to_list(mz, peaks[\"mz\"])\n", " if not np.isnan(peak_idx):\n", - " fragments.append((mz, node.identifier, None if node.identifier==0 else tree.parent(node.identifier).identifier))\n", - " if not peaks['mz'][peak_idx] in peak_list:\n", - " peak_list.append(peaks['mz'][peak_idx])\n", - " intensities.append(peaks['intensity'][peak_idx])\n", + " fragments.append(\n", + " (\n", + " mz,\n", + " node.identifier,\n", + " None\n", + " if node.identifier == 0\n", + " else tree.parent(node.identifier).identifier,\n", + " )\n", + " )\n", + " if not peaks[\"mz\"][peak_idx] in peak_list:\n", + " peak_list.append(peaks[\"mz\"][peak_idx])\n", + " intensities.append(peaks[\"intensity\"][peak_idx])\n", "\n", " return fragments, peak_list, intensities\n", "\n", - "print(\"Spectrum peak list: %s\" % d.peaks['mz'])\n", + "\n", + "print(\"Spectrum peak list: %s\" % d.peaks[\"mz\"])\n", "DEPTH = 2\n", "frags, p, i = get_matching_peaks_until_depth(tree=t, peaks=d.peaks, depth=DEPTH)\n", "print(\"Identified [M+H]+ fragments at treedepth %s: %s\" % (DEPTH, p))\n", "print(\"Fragment intensities %s\" % i)\n", - "print(\"Fragment intensities covered %.02f\" % (sum(i) / sum(d.peaks['intensity'])))\n", - "#collections.Counter([x[2] for x in frags])" + "print(\"Fragment intensities covered %.02f\" % (sum(i) / sum(d.peaks[\"intensity\"])))\n", + "# collections.Counter([x[2] for x in frags])" ] }, { @@ -6253,7 +6854,12 @@ "source": [ "# 1st order fragmentation\n", "mols = [d.mol, t.get_node(227).data, t.get_node(149).data, t.get_node(263).data]\n", - "Chem.Draw.MolsToGridImage(mols, molsPerRow=4, useSVG=True, legends=[\"mz: %.2f\" % (Chem.Descriptors.ExactMolWt(x) + PROTON_MZ) for x in mols])" + "Chem.Draw.MolsToGridImage(\n", + " mols,\n", + " molsPerRow=4,\n", + " useSVG=True,\n", + " legends=[\"mz: %.2f\" % (Chem.Descriptors.ExactMolWt(x) + PROTON_MZ) for x in mols],\n", + ")" ] }, { @@ -6274,11 +6880,16 @@ " if do_peaks_match(x, f[0]):\n", " parents.append(f[2])\n", " print(\"%s: %s\" % (x, parents))\n", - "parents = [t.get_node(170).data ,t.get_node(263).data , t.get_node(227).data]\n", + "parents = [t.get_node(170).data, t.get_node(263).data, t.get_node(227).data]\n", "mols = [t.get_node(171).data, t.get_node(283).data, t.get_node(232).data]\n", "\n", - "m = parents+mols\n", - "Chem.Draw.MolsToGridImage(m, molsPerRow=3, useSVG=True, legends=[\"mz: %.2f\" % (Chem.Descriptors.ExactMolWt(x) + PROTON_MZ) for x in m])" + "m = parents + mols\n", + "Chem.Draw.MolsToGridImage(\n", + " m,\n", + " molsPerRow=3,\n", + " useSVG=True,\n", + " legends=[\"mz: %.2f\" % (Chem.Descriptors.ExactMolWt(x) + PROTON_MZ) for x in m],\n", + ")" ] }, { @@ -6291,8 +6902,8 @@ }, "outputs": [], "source": [ - "#s = mspReader.get_spectrum_by_name()\n", 
- "sv.plot_spectrum(d,title=\"Phacidin MS/MS\")" + "# s = mspReader.get_spectrum_by_name()\n", + "sv.plot_spectrum(d, title=\"Phacidin MS/MS\")" ] }, { @@ -6344,7 +6955,7 @@ "s = \"CCCCCCCCC(=O)C1=C(C=C(C(=C1O)C=O)O)O\"\n", "m = Chem.MolFromSmiles(s)\n", "\n", - "m\n" + "m" ] }, { @@ -6385,7 +6996,7 @@ "source": [ "t = build_fragmentation_tree(d.mol, d.edges_idx, depth=2)\n", "t.show(idhidden=False)\n", - "print(get_matching_peaks_until_depth(t,d.peaks,depth=2))\n", + "print(get_matching_peaks_until_depth(t, d.peaks, depth=2))\n", "t.get_node(56).data" ] }, @@ -6401,11 +7012,11 @@ "source": [ "print(zigzagerrorhack)\n", "\n", - "df['fragmentation_tree'] = None\n", - "df['identified_peaks'] = None\n", - "df['identified_peaks_intensities'] = None\n", - "df['intensity_covered'] = None\n", - "df['num_identified_peaks'] = None\n", + "df[\"fragmentation_tree\"] = None\n", + "df[\"identified_peaks\"] = None\n", + "df[\"identified_peaks_intensities\"] = None\n", + "df[\"intensity_covered\"] = None\n", + "df[\"num_identified_peaks\"] = None\n", "\n", "\n", "DEPTH = 2\n", @@ -6413,25 +7024,24 @@ "c = 0\n", "for x in df.index:\n", " if c % 20 == 0:\n", - " print(\"%.02f%%\" % (100 * c/float(df.shape[0])), end='\\r')\n", + " print(\"%.02f%%\" % (100 * c / float(df.shape[0])), end=\"\\r\")\n", " d = df.loc[x]\n", " t = build_fragmentation_tree(d.mol, d.edges_idx, depth=DEPTH)\n", " frags, p, i = get_matching_peaks_until_depth(tree=t, peaks=d.peaks, depth=DEPTH)\n", - " intensity_covered = sum(i) / sum(d.peaks['intensity'])\n", - "\n", + " intensity_covered = sum(i) / sum(d.peaks[\"intensity\"])\n", "\n", " # FILTER: at least 1 peak other than the precursor, at least 0.5 of the total intensity covered\n", " if len(p) > 1 and intensity_covered > 0.5:\n", - " df.at[x,'fragmentation_tree'] = t\n", - " #df_nist.loc[x,'fragmentation_tree'] = t\n", - " #df_nist.loc[x]['fragmentation_tree'] = t\n", - " df.at[x, 'identified_peaks'] = p\n", - " df.at[x,'identified_peaks_intensities'] = i\n", - " df.at[x,'intensity_covered'] = intensity_covered\n", - " df.at[x,'num_identified_peaks'] = len(p)\n", + " df.at[x, \"fragmentation_tree\"] = t\n", + " # df_nist.loc[x,'fragmentation_tree'] = t\n", + " # df_nist.loc[x]['fragmentation_tree'] = t\n", + " df.at[x, \"identified_peaks\"] = p\n", + " df.at[x, \"identified_peaks_intensities\"] = i\n", + " df.at[x, \"intensity_covered\"] = intensity_covered\n", + " df.at[x, \"num_identified_peaks\"] = len(p)\n", " else:\n", " to_be_removed.append(x)\n", - " c+=1" + " c += 1" ] }, { @@ -6446,8 +7056,8 @@ "source": [ "print(zigzagerrorhack)\n", "\n", - "#df_nist.to_csv('./nist_reduced.csv')\n", - "df.to_csv(f'{home}/Desktop/nist_reduced.csv')\n", + "# df_nist.to_csv('./nist_reduced.csv')\n", + "df.to_csv(f\"{home}/Desktop/nist_reduced.csv\")\n", "\n", "print(\"yes\")\n", "print(zigzagerrorhack)" @@ -6475,7 +7085,8 @@ "outputs": [], "source": [ "import ast\n", - "df = pd.read_csv('./nist_reduced.csv')\n", + "\n", + "df = pd.read_csv(\"./nist_reduced.csv\")\n", "\n", "df.head()" ] @@ -6490,11 +7101,13 @@ }, "outputs": [], "source": [ - "df = df[~df['intensity_covered'].apply(np.isnan)]\n", - "df['peaks'] = df['peaks'].apply(ast.literal_eval)\n", - "df['identified_peaks'] = df['identified_peaks'].apply(ast.literal_eval)\n", - "df['identified_peaks_intensities'] = df['identified_peaks_intensities'].apply(ast.literal_eval)\n", - "print(df.shape)\n" + "df = df[~df[\"intensity_covered\"].apply(np.isnan)]\n", + "df[\"peaks\"] = df[\"peaks\"].apply(ast.literal_eval)\n", + 
"df[\"identified_peaks\"] = df[\"identified_peaks\"].apply(ast.literal_eval)\n", + "df[\"identified_peaks_intensities\"] = df[\"identified_peaks_intensities\"].apply(\n", + " ast.literal_eval\n", + ")\n", + "print(df.shape)" ] }, { @@ -6507,7 +7120,6 @@ }, "outputs": [], "source": [ - "\n", "def find_precursor_peak_idx(precursor_mz, peaks):\n", " for i in range(len(peaks)):\n", " if do_peaks_match(peaks[i], precursor_mz):\n", @@ -6515,11 +7127,31 @@ " return -1\n", "\n", "\n", - "df['precursor_peak_idx'] = df.apply(lambda x: find_precursor_peak_idx(x['PrecursorMZ'], x['peaks']['mz']) , axis=1)\n", - "df['intensity_covered_without_precursor'] = df.apply(lambda x: (sum(x['identified_peaks_intensities']) - x['peaks']['intensity'][x['precursor_peak_idx']]) / (sum(x['peaks']['intensity']) - x['peaks']['intensity'][x['precursor_peak_idx']]), axis=1)\n", + "df[\"precursor_peak_idx\"] = df.apply(\n", + " lambda x: find_precursor_peak_idx(x[\"PrecursorMZ\"], x[\"peaks\"][\"mz\"]), axis=1\n", + ")\n", + "df[\"intensity_covered_without_precursor\"] = df.apply(\n", + " lambda x: (\n", + " (\n", + " sum(x[\"identified_peaks_intensities\"])\n", + " - x[\"peaks\"][\"intensity\"][x[\"precursor_peak_idx\"]]\n", + " )\n", + " / (\n", + " sum(x[\"peaks\"][\"intensity\"])\n", + " - x[\"peaks\"][\"intensity\"][x[\"precursor_peak_idx\"]]\n", + " )\n", + " ),\n", + " axis=1,\n", + ")\n", "x = df.iloc[0]\n", - "#print((x['identified_peaks_intensities'][0]))\n", - "print((sum(x['identified_peaks_intensities']) - x['peaks']['intensity'][x['precursor_peak_idx']]) / (sum(x['peaks']['intensity']) - x['peaks']['intensity'][x['precursor_peak_idx']]))#\n", + "# print((x['identified_peaks_intensities'][0]))\n", + "print(\n", + " (\n", + " sum(x[\"identified_peaks_intensities\"])\n", + " - x[\"peaks\"][\"intensity\"][x[\"precursor_peak_idx\"]]\n", + " )\n", + " / (sum(x[\"peaks\"][\"intensity\"]) - x[\"peaks\"][\"intensity\"][x[\"precursor_peak_idx\"]])\n", + ") #\n", "print(x.peaks, x.identified_peaks, x.PrecursorMZ, x.intensity_covered_without_precursor)" ] }, @@ -6540,17 +7172,23 @@ "sns.boxplot(df.intensity_covered, color=sns.color_palette(\"Paired\")[0])\n", "plt.show()\n", "\n", - "plt.figure(figsize=(12,4))\n", - "sns.boxplot(df.intensity_covered_without_precursor, color=sns.color_palette(\"Paired\")[0])\n", + "plt.figure(figsize=(12, 4))\n", + "sns.boxplot(\n", + " df.intensity_covered_without_precursor, color=sns.color_palette(\"Paired\")[0]\n", + ")\n", "plt.show()\n", "\n", - "fig, axs = plt.subplots(2, 1, sharex='all', figsize=(8,6))\n", + "fig, axs = plt.subplots(2, 1, sharex=\"all\", figsize=(8, 6))\n", "sns.boxplot(df.intensity_covered, color=sns.color_palette(\"Paired\")[0], ax=axs[0])\n", - "sns.boxplot(df.intensity_covered_without_precursor, color=sns.color_palette(\"Paired\")[2], ax=axs[1])\n", + "sns.boxplot(\n", + " df.intensity_covered_without_precursor,\n", + " color=sns.color_palette(\"Paired\")[2],\n", + " ax=axs[1],\n", + ")\n", "\n", "plt.show()\n", "\n", - "ax = sns.boxplot(x=df['Num Peaks'], color=sns.color_palette(\"Paired\")[0])\n", + "ax = sns.boxplot(x=df[\"Num Peaks\"], color=sns.color_palette(\"Paired\")[0])\n", "ax.set_xlim([1, 50])\n", "plt.show()\n", "\n", @@ -6567,7 +7205,9 @@ }, "outputs": [], "source": [ - "df['fragmentation_tree'] = df.apply(lambda x: build_fragmentation_tree(x['mol'], x['edges_idx'], depth=2), axis=1)" + "df[\"fragmentation_tree\"] = df.apply(\n", + " lambda x: build_fragmentation_tree(x[\"mol\"], x[\"edges_idx\"], depth=2), axis=1\n", + ")" ] }, { @@ -6580,15 
+7220,25 @@ }, "outputs": [], "source": [ - "df['matching_peaks_d1'] = df.apply(lambda x: get_matching_peaks_until_depth(x['fragmentation_tree'], x['peaks'], 1), axis=1)\n", - "df['matching_peaks_d2'] = df.apply(lambda x: get_matching_peaks_until_depth(x['fragmentation_tree'], x['peaks'], 2), axis=1)\n", + "df[\"matching_peaks_d1\"] = df.apply(\n", + " lambda x: get_matching_peaks_until_depth(x[\"fragmentation_tree\"], x[\"peaks\"], 1),\n", + " axis=1,\n", + ")\n", + "df[\"matching_peaks_d2\"] = df.apply(\n", + " lambda x: get_matching_peaks_until_depth(x[\"fragmentation_tree\"], x[\"peaks\"], 2),\n", + " axis=1,\n", + ")\n", "\n", - "df['num_peaks_matched_d1'] = df['matching_peaks_d1'].apply(lambda x: len(x[1]))\n", - "df['num_peaks_matched_d2'] = df['matching_peaks_d2'].apply(lambda x: len(x[1]))\n", + "df[\"num_peaks_matched_d1\"] = df[\"matching_peaks_d1\"].apply(lambda x: len(x[1]))\n", + "df[\"num_peaks_matched_d2\"] = df[\"matching_peaks_d2\"].apply(lambda x: len(x[1]))\n", "\n", - "df['intensity_covered_d1'] = df.apply(lambda x: sum(x['matching_peaks_d1'][2]) / sum(x['peaks']['intensity']) , axis=1)\n", - "df['intensity_covered_d2'] = df.apply(lambda x: sum(x['matching_peaks_d2'][2]) / sum(x['peaks']['intensity']) , axis=1)\n", - "#df_nist[['num_peaks_matched_d1', 'intensity_covered_d1', 'num_peaks_matched_d2', 'intensity_covered_d2']]\n" + "df[\"intensity_covered_d1\"] = df.apply(\n", + " lambda x: sum(x[\"matching_peaks_d1\"][2]) / sum(x[\"peaks\"][\"intensity\"]), axis=1\n", + ")\n", + "df[\"intensity_covered_d2\"] = df.apply(\n", + " lambda x: sum(x[\"matching_peaks_d2\"][2]) / sum(x[\"peaks\"][\"intensity\"]), axis=1\n", + ")\n", + "# df_nist[['num_peaks_matched_d1', 'intensity_covered_d1', 'num_peaks_matched_d2', 'intensity_covered_d2']]" ] }, { @@ -6601,18 +7251,18 @@ }, "outputs": [], "source": [ - "fig, axs = plt.subplots(1, 2, sharey='all', figsize=(8,6))\n", - "sns.boxplot(ax=axs[0], y=df['num_peaks_matched_d1'])\n", - "sns.boxplot(ax=axs[1], y=df['num_peaks_matched_d2'])\n", + "fig, axs = plt.subplots(1, 2, sharey=\"all\", figsize=(8, 6))\n", + "sns.boxplot(ax=axs[0], y=df[\"num_peaks_matched_d1\"])\n", + "sns.boxplot(ax=axs[1], y=df[\"num_peaks_matched_d2\"])\n", "plt.show()\n", "\n", "\n", - "fig, axs = plt.subplots(1, 2, sharey='all', figsize=(8,6))\n", - "sns.boxplot(ax=axs[0], y=df['intensity_covered_d1'], color=\"pink\")\n", - "sns.boxplot(ax=axs[1], y=df['intensity_covered_d2'], color=\"pink\")\n", + "fig, axs = plt.subplots(1, 2, sharey=\"all\", figsize=(8, 6))\n", + "sns.boxplot(ax=axs[0], y=df[\"intensity_covered_d1\"], color=\"pink\")\n", + "sns.boxplot(ax=axs[1], y=df[\"intensity_covered_d2\"], color=\"pink\")\n", "plt.show()\n", "\n", - "print(sum(df[df['intensity_covered_d2'] > 0.5]['num_peaks_matched_d2'] > 1))" + "print(sum(df[df[\"intensity_covered_d2\"] > 0.5][\"num_peaks_matched_d2\"] > 1))" ] }, { @@ -6625,16 +7275,15 @@ }, "outputs": [], "source": [ - "\n", - "sns.histplot(df, x='num_peaks_matched', bins=range(0,12))\n", + "sns.histplot(df, x=\"num_peaks_matched\", bins=range(0, 12))\n", "plt.show()\n", - "#print(sum(df_nist['num_peaks_matched'] >= 3))\n", + "# print(sum(df_nist['num_peaks_matched'] >= 3))\n", "\n", - "#sns.histplot(df_nist, x='Num Peaks', bins=range(0,50))\n", - "#plt.show()\n", - "df_candidates = df#df_nist[df_nist['num_peaks_matched'] >= 3]\n", - "#sns.histplot(df_candidates, x='Num Peaks', bins=range(0,50))\n", - "#plt.show()" + "# sns.histplot(df_nist, x='Num Peaks', bins=range(0,50))\n", + "# plt.show()\n", + "df_candidates = 
df # df_nist[df_nist['num_peaks_matched'] >= 3]\n", + "# sns.histplot(df_candidates, x='Num Peaks', bins=range(0,50))\n", + "# plt.show()" ] }, { @@ -6647,11 +7296,11 @@ }, "outputs": [], "source": [ - "#df_candidates['high_intense'] = df_candidates.apply(lambda x: sum([any([(do_peaks_match(frag_peak, float(x['peaks']['mz'][i])) and int(x['peaks']['intensity'][i]) > 100) for i in range(len(x['peaks']['mz']))]) for frag_peak in x['unique_fragment_mz']]), axis=1)\n", + "# df_candidates['high_intense'] = df_candidates.apply(lambda x: sum([any([(do_peaks_match(frag_peak, float(x['peaks']['mz'][i])) and int(x['peaks']['intensity'][i]) > 100) for i in range(len(x['peaks']['mz']))]) for frag_peak in x['unique_fragment_mz']]), axis=1)\n", "\n", - "df_candidates = df_candidates[df_candidates['high_intense'] >= 2]\n", + "df_candidates = df_candidates[df_candidates[\"high_intense\"] >= 2]\n", "\n", - "sns.histplot(df_candidates, x='high_intense', bins=range(0,12))\n", + "sns.histplot(df_candidates, x=\"high_intense\", bins=range(0, 12))\n", "plt.show()" ] }, @@ -6682,15 +7331,25 @@ "\n", "matches = []\n", "for i in range(len(d.edges_idx)):\n", - " f1 = any([do_peaks_match(d.edge_fragment_mz[i][0], float(peak)) for peak in d['peaks']['mz']])\n", - " f2 = any([do_peaks_match(d.edge_fragment_mz[i][1], float(peak)) for peak in d['peaks']['mz']])\n", + " f1 = any(\n", + " [\n", + " do_peaks_match(d.edge_fragment_mz[i][0], float(peak))\n", + " for peak in d[\"peaks\"][\"mz\"]\n", + " ]\n", + " )\n", + " f2 = any(\n", + " [\n", + " do_peaks_match(d.edge_fragment_mz[i][1], float(peak))\n", + " for peak in d[\"peaks\"][\"mz\"]\n", + " ]\n", + " )\n", "\n", " print(d.edges_idx[i], d.edge_fragment_mz[i], f1, f2)\n", " if f1:\n", " matches.append((d.edges_idx[i], d.edge_fragment_mz[i][0]))\n", " if f2:\n", " matches.append((d.edges_idx[i], d.edge_fragment_mz[i][1]))\n", - "#draw_graph(d.graph, edge_labels=True)\n", + "# draw_graph(d.graph, edge_labels=True)\n", "print(d)\n", "d.mol" ] @@ -6706,9 +7365,14 @@ "outputs": [], "source": [ "print(matches)\n", - "nm, f = break_bond(d.mol, 1,2)\n", + "nm, f = break_bond(d.mol, 1, 2)\n", "\n", - "Chem.Draw.MolsToGridImage([d.mol, nm, f[0], f[1]], molsPerRow=4, useSVG=True, legends=[\"intact\", \"broken\", \"fragment 1\", \"fragment 2\"])\n" + "Chem.Draw.MolsToGridImage(\n", + " [d.mol, nm, f[0], f[1]],\n", + " molsPerRow=4,\n", + " useSVG=True,\n", + " legends=[\"intact\", \"broken\", \"fragment 1\", \"fragment 2\"],\n", + ")" ] }, { @@ -6721,9 +7385,14 @@ }, "outputs": [], "source": [ - "nm, f = break_bond(d.mol, 6,7)\n", + "nm, f = break_bond(d.mol, 6, 7)\n", "\n", - "Chem.Draw.MolsToGridImage([d.mol, nm, f[0], f[1]], molsPerRow=4, useSVG=True, legends=[\"intact\", \"broken\", \"fragment 1\", \"fragment 2\"])" + "Chem.Draw.MolsToGridImage(\n", + " [d.mol, nm, f[0], f[1]],\n", + " molsPerRow=4,\n", + " useSVG=True,\n", + " legends=[\"intact\", \"broken\", \"fragment 1\", \"fragment 2\"],\n", + ")" ] }, { @@ -6749,9 +7418,10 @@ }, "outputs": [], "source": [ - "\n", "fragmentation_tree = Tree()\n", - "fragmentation_tree.create_node(tag=Chem.Descriptors.ExactMolWt(d.mol), identifier=0, data=d.mol)\n", + "fragmentation_tree.create_node(\n", + " tag=Chem.Descriptors.ExactMolWt(d.mol), identifier=0, data=d.mol\n", + ")\n", "\n", "# First order fragmentations\n", "c = 1\n", @@ -6759,7 +7429,7 @@ "#\n", "# TODO Extremely Ugly - Complete Overhaul Needed\n", "#\n", - "for i,j in d.edges_idx:\n", + "for i, j in d.edges_idx:\n", " _, fragments = create_fragments(d.mol, i, j)\n", " 
for f in fragments:\n", " if f is not None:\n", @@ -6769,8 +7439,10 @@ " has_fragment = True\n", " break\n", " if not has_fragment:\n", - " fragmentation_tree.create_node(tag=Chem.Descriptors.ExactMolWt(f), identifier=c, parent=0, data=f)\n", - " c+=1\n", + " fragmentation_tree.create_node(\n", + " tag=Chem.Descriptors.ExactMolWt(f), identifier=c, parent=0, data=f\n", + " )\n", + " c += 1\n", "\n", "for node in fragmentation_tree.children(0):\n", " parent_id = node.identifier\n", @@ -6778,8 +7450,8 @@ " G = mol_to_nx(MOL)\n", " A = nx.convert_matrix.to_numpy_matrix(G)\n", " deg = tf.transpose([tf.clip_by_value(tf.reduce_sum(A, axis=-1), 0.0001, 1000.0)])\n", - " _,_,_, edge_idx = compute_helper_matrices(A, deg, G)\n", - " for i,j in edge_idx:\n", + " _, _, _, edge_idx = compute_helper_matrices(A, deg, G)\n", + " for i, j in edge_idx:\n", " _, fragments = create_fragments(MOL, i, j)\n", " for f in fragments:\n", " if f is not None:\n", @@ -6789,16 +7461,31 @@ " has_fragment = True\n", " break\n", " if not has_fragment:\n", - " fragmentation_tree.create_node(tag=Chem.Descriptors.ExactMolWt(f), identifier=c, parent=parent_id, data=f)\n", - " c+=1\n", + " fragmentation_tree.create_node(\n", + " tag=Chem.Descriptors.ExactMolWt(f),\n", + " identifier=c,\n", + " parent=parent_id,\n", + " data=f,\n", + " )\n", + " c += 1\n", "\n", "\n", - "#fragmentation_tree.create_node(tag=Chem.Descriptors.ExactMolWt(f[0]), parent=0, data=Chem.Descriptors.ExactMolWt(f[0]))\n", - "#fragmentation_tree.create_node(tag=Chem.Descriptors.ExactMolWt(f[1]), parent=0, data=Chem.Descriptors.ExactMolWt(f[1]))\n", + "# fragmentation_tree.create_node(tag=Chem.Descriptors.ExactMolWt(f[0]), parent=0, data=Chem.Descriptors.ExactMolWt(f[0]))\n", + "# fragmentation_tree.create_node(tag=Chem.Descriptors.ExactMolWt(f[1]), parent=0, data=Chem.Descriptors.ExactMolWt(f[1]))\n", "\n", "fragmentation_tree.show(idhidden=False)\n", "\n", - "Chem.Draw.MolsToGridImage([d.mol, fragmentation_tree.get_node(25).data, fragmentation_tree.get_node(2).data, fragmentation_tree.get_node(17).data], molsPerRow=4, useSVG=True, legends=[\"intact\", \"fragment 1\", \"fragment 2\", \"fragment 3\"])\n" + "Chem.Draw.MolsToGridImage(\n", + " [\n", + " d.mol,\n", + " fragmentation_tree.get_node(25).data,\n", + " fragmentation_tree.get_node(2).data,\n", + " fragmentation_tree.get_node(17).data,\n", + " ],\n", + " molsPerRow=4,\n", + " useSVG=True,\n", + " legends=[\"intact\", \"fragment 1\", \"fragment 2\", \"fragment 3\"],\n", + ")" ] }, { @@ -6830,7 +7517,7 @@ "for node in fragmentation_tree.all_nodes():\n", " f = node.data\n", " mz = Chem.Descriptors.ExactMolWt(f) + PROTON_MZ\n", - " if any([do_peaks_match(mz, float(peak)) for peak in d['peaks']['mz']]):\n", + " if any([do_peaks_match(mz, float(peak)) for peak in d[\"peaks\"][\"mz\"]]):\n", " if node.identifier == 0:\n", " precursor_mz.append(mz)\n", " elif fragmentation_tree.parent(node.identifier).identifier == 0:\n", @@ -6844,7 +7531,12 @@ "print(np.unique(precursor_mz))\n", "print(np.unique(order1_fragments))\n", "print(np.unique(order2_fragments))\n", - "Chem.Draw.MolsToGridImage(o1_f, molsPerRow=4, useSVG=True, legends=[\"intact\", \"fragment 1\", \"fragment 2\", \"fragment 3\"])" + "Chem.Draw.MolsToGridImage(\n", + " o1_f,\n", + " molsPerRow=4,\n", + " useSVG=True,\n", + " legends=[\"intact\", \"fragment 1\", \"fragment 2\", \"fragment 3\"],\n", + ")" ] }, { @@ -6857,8 +7549,8 @@ }, "outputs": [], "source": [ - "print(d.peaks['mz'])\n", - "print(d.peaks['intensity'])\n", + "print(d.peaks[\"mz\"])\n", 
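The fragmentation-tree loop above grows the tree by calling `break_bond` / `create_fragments` on every graph edge and deduplicating fragments by exact weight. Those helpers live elsewhere in the repository; the following is only a hedged sketch of what a single bond cleavage looks like in RDKit, with an illustrative name, not the actual implementation:

from rdkit import Chem

def break_bond_sketch(mol, atom_i: int, atom_j: int):
    # Cleave the bond between two atom indices and return the broken
    # molecule plus its connected components as separate Mol objects.
    bond = mol.GetBondBetweenAtoms(atom_i, atom_j)
    if bond is None:
        return mol, []
    broken = Chem.FragmentOnBonds(mol, [bond.GetIdx()], addDummies=False)
    frags = Chem.GetMolFrags(broken, asMols=True, sanitizeFrags=False)
    # A ring-bond cleavage yields a single fragment; a chain cleavage two.
    return broken, list(frags)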
+ "print(d.peaks[\"intensity\"])\n", "print(np.unique(parents))" ] }, diff --git a/lib_loader/ms_dial_loader.ipynb b/lib_loader/ms_dial_loader.ipynb index ed38a7a..f261461 100644 --- a/lib_loader/ms_dial_loader.ipynb +++ b/lib_loader/ms_dial_loader.ipynb @@ -15,7 +15,8 @@ ], "source": [ "import sys\n", - "print(f'Working with Python {sys.version}')\n", + "\n", + "print(f\"Working with Python {sys.version}\")\n", "\n", "import pandas as pd\n", "from rdkit import Chem\n", @@ -26,20 +27,22 @@ "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import os\n", - "#import pymzml\n", + "\n", + "# import pymzml\n", "import numpy as np\n", "from rdkit import RDLogger\n", - "RDLogger.DisableLog('rdApp.*')\n", "\n", + "RDLogger.DisableLog(\"rdApp.*\")\n", "\n", "\n", "# Load Modules\n", "sys.path.append(\"..\")\n", "from os.path import expanduser\n", + "\n", "home = expanduser(\"~\")\n", "import fiora.IO.mspReader as mspReader\n", "import fiora.IO.mgfReader as mgfReader\n", - "import fiora.visualization.spectrum_visualizer as sv\n" + "import fiora.visualization.spectrum_visualizer as sv" ] }, { @@ -48,10 +51,17 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "p = pd.DataFrame(mspReader.read(f\"{home}/data/metabolites/MS_DIAL/MSMS_Public_EXP_Pos_VS17.msp\", sep=\"\\t\"))\n", - "n = pd.DataFrame(mspReader.read(f\"{home}/data/metabolites/MS_DIAL/MSMS_Public_EXP_NEG_VS17.msp\", sep=\"\\t\"))\n", - "df = pd.concat([p, n])\n" + "p = pd.DataFrame(\n", + " mspReader.read(\n", + " f\"{home}/data/metabolites/MS_DIAL/MSMS_Public_EXP_Pos_VS17.msp\", sep=\"\\t\"\n", + " )\n", + ")\n", + "n = pd.DataFrame(\n", + " mspReader.read(\n", + " f\"{home}/data/metabolites/MS_DIAL/MSMS_Public_EXP_NEG_VS17.msp\", sep=\"\\t\"\n", + " )\n", + ")\n", + "df = pd.concat([p, n])" ] }, { @@ -104,7 +114,6 @@ } ], "source": [ - "\n", "df[df[\"CE_float\"]][\"INSTRUMENTTYPE\"].value_counts()" ] }, @@ -119,7 +128,7 @@ " try:\n", " df.at[i, \"origin\"] = d[\"COMMENT\"].split(\"origin=\")[1]\n", " except:\n", - " if \"MetaboBASE\" in d[\"COMMENT\"]: \n", + " if \"MetaboBASE\" in d[\"COMMENT\"]:\n", " df.at[i, \"origin\"] = \"MetaboBASE\"" ] }, @@ -185,7 +194,11 @@ } ], "source": [ - "sns.histplot(df[df[\"origin\"] == \"Vaniya/Fiehn Natural Products Library\"][\"RETENTIONTIME\"].astype(float))" + "sns.histplot(\n", + " df[df[\"origin\"] == \"Vaniya/Fiehn Natural Products Library\"][\"RETENTIONTIME\"].astype(\n", + " float\n", + " )\n", + ")" ] }, { @@ -204,7 +217,7 @@ "metadata": {}, "outputs": [], "source": [ - "#df[\"SMILES\"].apply(Chem.MolFromSmiles).isna().any()" + "# df[\"SMILES\"].apply(Chem.MolFromSmiles).isna().any()" ] }, { @@ -236,10 +249,19 @@ } ], "source": [ - "potential_homogenous_RT_libs = [\"RIKEN Plant Specialized Metabolome Annotation (PlaSMA) Authentic Standard Library\", 'BMDMS-NP'] #, 'Vaniya/Fiehn Natural Products Library']#, \"Global Natural Product Social Molecular Networking Library\"]\n", + "potential_homogenous_RT_libs = [\n", + " \"RIKEN Plant Specialized Metabolome Annotation (PlaSMA) Authentic Standard Library\",\n", + " \"BMDMS-NP\",\n", + "] # , 'Vaniya/Fiehn Natural Products Library']#, \"Global Natural Product Social Molecular Networking Library\"]\n", "\n", - "sns.histplot(data=df[df[\"origin\"].isin(potential_homogenous_RT_libs)], x=\"RETENTIONTIME\", hue=\"origin\", common_norm=False, stat=\"density\")\n", - "plt.xlim([0,20])" + "sns.histplot(\n", + " data=df[df[\"origin\"].isin(potential_homogenous_RT_libs)],\n", + " x=\"RETENTIONTIME\",\n", + " hue=\"origin\",\n", + " 
common_norm=False,\n", + " stat=\"density\",\n", + ")\n", + "plt.xlim([0, 20])" ] }, { @@ -272,7 +294,9 @@ } ], "source": [ - "sns.histplot(df[\"SMILES\"], )\n", + "sns.histplot(\n", + " df[\"SMILES\"],\n", + ")\n", "df[\"origin\"].value_counts()" ] }, @@ -346,7 +370,7 @@ "source": [ "precursor_types = [\"[M+H]+\", \"[M-H]-\"]\n", "\n", - "df = df[df[\"PRECURSORTYPE\"].apply(lambda x: x in precursor_types)]\n" + "df = df[df[\"PRECURSORTYPE\"].apply(lambda x: x in precursor_types)]" ] }, { @@ -358,7 +382,12 @@ "from modules.MOL.collision_energy import align_CE\n", "\n", "df[\"PRECURSORMZ\"] = df[\"PRECURSORMZ\"].astype(float)\n", - "df[\"CE\"] = df.apply(lambda x: align_CE(x[\"COLLISIONENERGY\"], x[\"PRECURSORMZ\"], instrument=x[\"INSTRUMENTTYPE\"]), axis=1)" + "df[\"CE\"] = df.apply(\n", + " lambda x: align_CE(\n", + " x[\"COLLISIONENERGY\"], x[\"PRECURSORMZ\"], instrument=x[\"INSTRUMENTTYPE\"]\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -380,7 +409,7 @@ "source": [ "df[\"CE_type\"] = df[\"CE\"].apply(type)\n", "print(df[\"CE_type\"].value_counts())\n", - "print(sum(df[\"CE\"] == \"\")) # TODO assign random CE or 35 NCE ??)" + "print(sum(df[\"CE\"] == \"\")) # TODO assign random CE or 35 NCE ??)" ] }, { @@ -489,7 +518,7 @@ "df = df[df[\"CE_type\"] == float]\n", "df[\"CE\"] = df[\"CE\"].astype(float)\n", "df = df[df[\"CE\"] <= 1000.0]\n", - "df.reset_index(inplace=True)\n" + "df.reset_index(inplace=True)" ] }, { @@ -518,7 +547,6 @@ } ], "source": [ - "\n", "sns.displot(data=df, x=\"CE\", binwidth=5, kde=False)\n", "plt.xlim([0, 150])\n", "plt.show()\n", @@ -554,8 +582,15 @@ "fig, axs = plt.subplots(1, 1, figsize=(12.8, 6.4), sharey=False)\n", "\n", "\n", - "top_instrumenttypes = df['INSTRUMENTTYPE'].value_counts().head(6).index\n", - "sns.histplot(data=df[df['INSTRUMENTTYPE'].isin(top_instrumenttypes)], x=\"CE\", hue=\"INSTRUMENTTYPE\", multiple=\"stack\", binwidth=5, kde=False)\n", + "top_instrumenttypes = df[\"INSTRUMENTTYPE\"].value_counts().head(6).index\n", + "sns.histplot(\n", + " data=df[df[\"INSTRUMENTTYPE\"].isin(top_instrumenttypes)],\n", + " x=\"CE\",\n", + " hue=\"INSTRUMENTTYPE\",\n", + " multiple=\"stack\",\n", + " binwidth=5,\n", + " kde=False,\n", + ")\n", "plt.xlim([0, 150])\n", "plt.show()\n", "print(df.shape)" @@ -569,7 +604,7 @@ "source": [ "from modules.MOL.Metabolite import Metabolite\n", "\n", - "df[\"Metabolite\"] = df[\"SMILES\"].apply(Metabolite)\n" + "df[\"Metabolite\"] = df[\"SMILES\"].apply(Metabolite)" ] }, { @@ -600,7 +635,10 @@ "source": [ "from modules.MOL.constants import ADDUCT_WEIGHTS, PPM\n", "\n", - "df[\"Precursor_offset\"] = df[\"PRECURSORMZ\"] - df.apply(lambda x: x[\"Metabolite\"].ExactMolWeight + ADDUCT_WEIGHTS[x[\"PRECURSORTYPE\"]], axis=1)\n", + "df[\"Precursor_offset\"] = df[\"PRECURSORMZ\"] - df.apply(\n", + " lambda x: x[\"Metabolite\"].ExactMolWeight + ADDUCT_WEIGHTS[x[\"PRECURSORTYPE\"]],\n", + " axis=1,\n", + ")\n", "df[\"Precursor_abs_error\"] = abs(df[\"Precursor_offset\"])\n", "df[\"Precursor_rel_error\"] = df[\"Precursor_abs_error\"] / df[\"PRECURSORMZ\"]\n", "df[\"Precursor_ppm_error\"] = df[\"Precursor_abs_error\"] / (df[\"PRECURSORMZ\"] * PPM)\n", @@ -620,6 +658,7 @@ "%%capture\n", "from modules.MOL.Metabolite import Metabolite\n", "from modules.MOL.constants import PPM\n", + "\n", "TOLERANCE = 200 * PPM\n", "\n", "\n", @@ -628,7 +667,12 @@ "df[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", "df[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes())\n", "df[\"Metabolite\"].apply(lambda x: 
x.fragment_MOL())\n", - "df.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=TOLERANCE), axis=1)" + "df.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=TOLERANCE\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -641,7 +685,7 @@ "# fragments = Chem.GetMolFrags(d[\"Metabolite\"].MOL, asMols=True)\n", "# if len(fragments) == 1:\n", "# d[\"Metabolite\"].fragment_MOL()\n", - "# else: \n", + "# else:\n", "# try:\n", "# d[\"Metabolite\"].fragment_MOL()\n", "# print(\"Unexpected\")\n", @@ -658,10 +702,19 @@ "source": [ "# Define figure styles\n", "color_palette = sns.color_palette(\"magma_r\", 8)\n", - "sns.set_theme(style=\"whitegrid\",\n", - " rc={'axes.edgecolor': 'black', 'ytick.left': True, 'xtick.bottom': True, 'xtick.color': 'black',\n", - " \"axes.spines.bottom\": True, \"axes.spines.right\": True, \"axes.spines.top\": True,\n", - " \"axes.spines.left\": True})" + "sns.set_theme(\n", + " style=\"whitegrid\",\n", + " rc={\n", + " \"axes.edgecolor\": \"black\",\n", + " \"ytick.left\": True,\n", + " \"xtick.bottom\": True,\n", + " \"xtick.color\": \"black\",\n", + " \"axes.spines.bottom\": True,\n", + " \"axes.spines.right\": True,\n", + " \"axes.spines.top\": True,\n", + " \"axes.spines.left\": True,\n", + " },\n", + ")" ] }, { @@ -689,15 +742,20 @@ ], "source": [ "from modules.MOL.mol_graph import draw_graph\n", + "\n", "x = df.iloc[0]\n", "x_mol = x[\"Metabolite\"].MOL\n", - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 1]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 1]}, sharey=False\n", + ")\n", "\n", "img = Chem.Draw.MolToImage(x_mol, ax=axs[0])\n", "\n", "axs[0].grid(False)\n", - "axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False)\n", - "axs[0].set_title(x[\"NAME\"]+ \" structure:\\n\" + Chem.MolToSmiles(x_mol))\n", + "axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + ")\n", + "axs[0].set_title(x[\"NAME\"] + \" structure:\\n\" + Chem.MolToSmiles(x_mol))\n", "axs[0].imshow(img)\n", "axs[0].axis(\"off\")\n", "\n", @@ -769,13 +827,15 @@ ], "source": [ "from modules.MOL.constants import DEFAULT_MODES\n", + "\n", "df[\"peak_matches\"] = df[\"Metabolite\"].apply(lambda x: getattr(x, \"peak_matches\"))\n", "df[\"num_peaks_matched\"] = df[\"peak_matches\"].apply(len)\n", "\n", + "\n", "def get_match_stats(matches, mode_count={m: 0 for m in DEFAULT_MODES}):\n", " num_unique, num_conflicts = 0, 0\n", " for mz, match_data in matches.items():\n", - " #candidates = match_data[\"fragments\"]\n", + " # candidates = match_data[\"fragments\"]\n", " ion_modes = match_data[\"ion_modes\"]\n", " if len(ion_modes) == 1:\n", " num_unique += 1\n", @@ -786,14 +846,15 @@ " return num_unique, num_conflicts, mode_count\n", "\n", "\n", - "\n", "df[\"match_stats\"] = df[\"peak_matches\"].apply(lambda x: get_match_stats(x))\n", "df[\"num_unique_peaks_matched\"] = df.apply(lambda x: x[\"match_stats\"][0], axis=1)\n", "df[\"num_conflicts_in_peak_matching\"] = df.apply(lambda x: x[\"match_stats\"][1], axis=1)\n", "df[\"match_mode_counts\"] = df.apply(lambda x: x[\"match_stats\"][2], axis=1)\n", - "u= df[\"num_unique_peaks_matched\"].sum() \n", - "s= df[\"num_conflicts_in_peak_matching\"].sum() \n", - "print(f\"Total number of uniquely matched peaks: 
{u} , conflicts found within {s} matches ({100 * s / (u+s):.02f} %))\")\n", + "u = df[\"num_unique_peaks_matched\"].sum()\n", + "s = df[\"num_conflicts_in_peak_matching\"].sum()\n", + "print(\n", + " f\"Total number of uniquely matched peaks: {u} , conflicts found within {s} matches ({100 * s / (u + s):.02f} %))\"\n", + ")\n", "print(f\"Total number of conflicting peak to fragment matches: {s}\")" ] }, @@ -817,20 +878,33 @@ "fig, axs = plt.subplots(1, 3, figsize=(12.8, 6.4), sharey=False)\n", "\n", "fig.suptitle(f\"Identified peaks with fragment offset\")\n", - "#plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n", - "sns.histplot(ax=axs[0],data=df, x=\"num_peaks_matched\", color=color_palette[0], edgecolor=\"black\", bins=range(0,20, 1))\n", - "#axs[0].set_ylim(-0.5, 10)\n", + "# plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n", + "sns.histplot(\n", + " ax=axs[0],\n", + " data=df,\n", + " x=\"num_peaks_matched\",\n", + " color=color_palette[0],\n", + " edgecolor=\"black\",\n", + " bins=range(0, 20, 1),\n", + ")\n", + "# axs[0].set_ylim(-0.5, 10)\n", "axs[0].set_ylabel(\"peaks identified\")\n", "\n", "\n", - "sns.boxplot(ax=axs[1],data=df, y=\"num_unique_peaks_matched\", color=color_palette[1])\n", + "sns.boxplot(ax=axs[1], data=df, y=\"num_unique_peaks_matched\", color=color_palette[1])\n", "axs[1].set_ylim(-0.5, 15)\n", "axs[1].set_xlabel(\"unique matches\")\n", "axs[1].set_ylabel(\"\")\n", "\n", "\n", - "sns.histplot(ax=axs[2],data=df, x=\"num_conflicts_in_peak_matching\", color=color_palette[3], binwidth=1)\n", - "#axs[2].set_ylim(-0.5, 1000)\n", + "sns.histplot(\n", + " ax=axs[2],\n", + " data=df,\n", + " x=\"num_conflicts_in_peak_matching\",\n", + " color=color_palette[3],\n", + " binwidth=1,\n", + ")\n", + "# axs[2].set_ylim(-0.5, 1000)\n", "axs[2].set_xlabel(\"conflicts\")\n", "axs[2].set_ylabel(\"\")\n", "\n", @@ -858,14 +932,28 @@ "\n", "mode_counts = {m: 0 for m in DEFAULT_MODES}\n", "\n", + "\n", "def update_mode_counts(m):\n", " for mode in m.keys():\n", " mode_counts[mode] += m[mode]\n", "\n", + "\n", "df[\"match_mode_counts\"].apply(update_mode_counts)\n", "\n", - "sns.barplot(ax=axs[0], x=list(mode_counts.keys()), y=[mode_counts[k] for k in mode_counts.keys()], palette=color_palette, edgecolor=\"black\", linewidth=1.5)\n", - "axs[1].pie([mode_counts[k] for k in mode_counts.keys()], labels=list(mode_counts.keys()), colors=color_palette, wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5})\n", + "sns.barplot(\n", + " ax=axs[0],\n", + " x=list(mode_counts.keys()),\n", + " y=[mode_counts[k] for k in mode_counts.keys()],\n", + " palette=color_palette,\n", + " edgecolor=\"black\",\n", + " linewidth=1.5,\n", + ")\n", + "axs[1].pie(\n", + " [mode_counts[k] for k in mode_counts.keys()],\n", + " labels=list(mode_counts.keys()),\n", + " colors=color_palette,\n", + " wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5},\n", + ")\n", "\n", "plt.show()" ] @@ -904,12 +992,12 @@ "df.dropna(subset=[\"SMILES\"], inplace=True)\n", "df[\"in_casmi2016\"] = False\n", "\n", - "for i,d in df_cas.iterrows():\n", + "for i, d in df_cas.iterrows():\n", " m = d[\"Metabolite\"]\n", - " \n", - " for x,D in df.iterrows():\n", + "\n", + " for x, D in df.iterrows():\n", " M = D[\"Metabolite\"]\n", - " if (m == M):\n", + " if m == M:\n", " df.at[x, \"in_casmi2016\"] = True\n", "del df_cas" ] @@ -933,9 +1021,11 @@ } ], "source": [ - "\n", - "for i in range(0,6):\n", - " print(f\"Minimum {i} unique peaks identified (including precursors): \", 
(df[\"num_unique_peaks_matched\"] >= i).sum())\n" + "for i in range(0, 6):\n", + " print(\n", + " f\"Minimum {i} unique peaks identified (including precursors): \",\n", + " (df[\"num_unique_peaks_matched\"] >= i).sum(),\n", + " )" ] }, { @@ -1034,15 +1124,15 @@ "save_df = False\n", "lib = f\"{home}/data/metabolites/MS_DIAL/\"\n", "name = \"ms_dial_filtered\"\n", - "date = \"XXX\" # \"mid_08_2023\" #\"mid_08_2023\" #\"07_2023\"\n", + "date = \"XXX\" # \"mid_08_2023\" #\"mid_08_2023\" #\"07_2023\"\n", "min_peaks = 5\n", "\n", "if save_df:\n", - " file = lib + name + \"_min\" + str(min_peaks) + \"_\" + date + \".csv\"\n", - " print(\"saving to \", file)\n", - " df[df[\"num_unique_peaks_matched\"] >= min_peaks].to_csv(file)\n", - " \n", - " #df.to_csv(lib + name + \"_all\" + \"_\" + date + \".csv\") #TODO HERE" + " file = lib + name + \"_min\" + str(min_peaks) + \"_\" + date + \".csv\"\n", + " print(\"saving to \", file)\n", + " df[df[\"num_unique_peaks_matched\"] >= min_peaks].to_csv(file)\n", + "\n", + " # df.to_csv(lib + name + \"_all\" + \"_\" + date + \".csv\") #TODO HERE" ] }, { @@ -1117,11 +1207,25 @@ } ], "source": [ - "\n", - "sns.histplot(df[~df[\"in_casmi2016\"]], x=\"RETENTIONTIME\", kde=True, binwidth=1, stat=\"density\", multiple=\"stack\")\n", - "sns.histplot(df[df[\"in_casmi2016\"]], x=\"RETENTIONTIME\", kde=True, binwidth=1, stat=\"density\", multiple=\"stack\", color=\"orange\")\n", + "sns.histplot(\n", + " df[~df[\"in_casmi2016\"]],\n", + " x=\"RETENTIONTIME\",\n", + " kde=True,\n", + " binwidth=1,\n", + " stat=\"density\",\n", + " multiple=\"stack\",\n", + ")\n", + "sns.histplot(\n", + " df[df[\"in_casmi2016\"]],\n", + " x=\"RETENTIONTIME\",\n", + " kde=True,\n", + " binwidth=1,\n", + " stat=\"density\",\n", + " multiple=\"stack\",\n", + " color=\"orange\",\n", + ")\n", "plt.legend(labels=[\"Non-Casmi\", \"Casmi2016\"])\n", - "plt.xlim([0,30])\n", + "plt.xlim([0, 30])\n", "plt.show()" ] }, @@ -1155,8 +1259,23 @@ ], "source": [ "df[\"CCS\"] = df[\"CCS\"].astype(float)\n", - "sns.histplot(df[~df[\"in_casmi2016\"]], x=\"CCS\", kde=True, binwidth=10, stat=\"density\", multiple=\"stack\")\n", - "sns.histplot(df[df[\"in_casmi2016\"]], x=\"CCS\", kde=True, binwidth=10, stat=\"density\", multiple=\"stack\", color=\"orange\")\n", + "sns.histplot(\n", + " df[~df[\"in_casmi2016\"]],\n", + " x=\"CCS\",\n", + " kde=True,\n", + " binwidth=10,\n", + " stat=\"density\",\n", + " multiple=\"stack\",\n", + ")\n", + "sns.histplot(\n", + " df[df[\"in_casmi2016\"]],\n", + " x=\"CCS\",\n", + " kde=True,\n", + " binwidth=10,\n", + " stat=\"density\",\n", + " multiple=\"stack\",\n", + " color=\"orange\",\n", + ")\n", "plt.legend(labels=[\"Non-Casmi\", \"Casmi2016\"])\n", "\n", "plt.show()" @@ -1179,10 +1298,17 @@ } ], "source": [ - "\n", "fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4))\n", "\n", - "sns.histplot(ax=ax, data=df, x=\"RETENTIONTIME\", hue='origin', multiple=\"stack\", binwidth=1, stat=\"probability\")\n", + "sns.histplot(\n", + " ax=ax,\n", + " data=df,\n", + " x=\"RETENTIONTIME\",\n", + " hue=\"origin\",\n", + " multiple=\"stack\",\n", + " binwidth=1,\n", + " stat=\"probability\",\n", + ")\n", "plt.xlim([0, 20])\n", "plt.show()" ] @@ -1204,15 +1330,14 @@ } ], "source": [ - "\n", "fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4))\n", "\n", - "sns.kdeplot(ax=ax, data=df, x=\"RETENTIONTIME\", hue='origin', multiple=\"fill\", common_norm=False)\n", + "sns.kdeplot(\n", + " ax=ax, data=df, x=\"RETENTIONTIME\", hue=\"origin\", multiple=\"fill\", common_norm=False\n", + ")\n", 
"ax.legend(bbox_to_anchor=(1.5, 0.8), labels=df[\"origin\"].unique())\n", "plt.xlim([0, 30])\n", - "plt.show()\n", - " \n", - " " + "plt.show()" ] }, { @@ -1232,10 +1357,11 @@ } ], "source": [ - "\n", "fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4))\n", "\n", - "sns.kdeplot(ax=ax, data=df, x=\"RETENTIONTIME\", hue='origin', multiple=\"layer\", common_norm=False)\n", + "sns.kdeplot(\n", + " ax=ax, data=df, x=\"RETENTIONTIME\", hue=\"origin\", multiple=\"layer\", common_norm=False\n", + ")\n", "plt.legend(labels=df[\"origin\"].unique())\n", "plt.xlim([0, 30])\n", "plt.show()" @@ -1378,7 +1504,7 @@ } ], "source": [ - "#sns.catplot(data=df, x=\"origin\")\n", + "# sns.catplot(data=df, x=\"origin\")\n", "sns.histplot(data=df, x=\"origin\")" ] }, @@ -1409,8 +1535,15 @@ "fig, axs = plt.subplots(1, 1, figsize=(12.8, 6.4), sharey=False)\n", "\n", "\n", - "top_instrumenttypes = df['INSTRUMENTTYPE'].value_counts().head(6).index\n", - "sns.histplot(data=df[df['INSTRUMENTTYPE'].isin(top_instrumenttypes)], x=\"CE\", hue=\"INSTRUMENTTYPE\", multiple=\"stack\", binwidth=5, kde=False)\n", + "top_instrumenttypes = df[\"INSTRUMENTTYPE\"].value_counts().head(6).index\n", + "sns.histplot(\n", + " data=df[df[\"INSTRUMENTTYPE\"].isin(top_instrumenttypes)],\n", + " x=\"CE\",\n", + " hue=\"INSTRUMENTTYPE\",\n", + " multiple=\"stack\",\n", + " binwidth=5,\n", + " kde=False,\n", + ")\n", "plt.xlim([0, 150])\n", "plt.show()\n", "print(df.shape)" diff --git a/lib_loader/msnlib_loader.ipynb b/lib_loader/msnlib_loader.ipynb index d72ef57..8149513 100644 --- a/lib_loader/msnlib_loader.ipynb +++ b/lib_loader/msnlib_loader.ipynb @@ -15,7 +15,8 @@ ], "source": [ "import sys\n", - "print(f'Working with Python {sys.version}')\n", + "\n", + "print(f\"Working with Python {sys.version}\")\n", "\n", "import pandas as pd\n", "from rdkit import Chem\n", @@ -29,17 +30,18 @@ "import ast\n", "import numpy as np\n", "from rdkit import RDLogger\n", - "RDLogger.DisableLog('rdApp.*')\n", "\n", + "RDLogger.DisableLog(\"rdApp.*\")\n", "\n", "\n", "# Load Modules\n", "sys.path.append(\"..\")\n", "from os.path import expanduser\n", + "\n", "home = expanduser(\"~\")\n", "import fiora.IO.mspReader as mspReader\n", "import fiora.IO.mgfReader as mgfReader\n", - "import fiora.visualization.spectrum_visualizer as sv\n" + "import fiora.visualization.spectrum_visualizer as sv" ] }, { @@ -74,7 +76,7 @@ ], "source": [ "version = \"v7\"\n", - "path: str = f\"{home}/data/metabolites/MSnLib/{version}/\" \n", + "path: str = f\"{home}/data/metabolites/MSnLib/{version}/\"\n", "\n", "dfs = []\n", "for filename in os.listdir(path):\n", @@ -119,7 +121,7 @@ " lambda x: [float(v) for v in x.strip(\"[]\").split(delim)] if \"[\" in x else [float(x)]\n", ")\n", "df[\"Num_steps\"] = df[\"CE_steps\"].apply(len)\n", - "df[\"CE\"] = df[\"CE_steps\"].apply(lambda x: sum(x) / len(x))\n" + "df[\"CE\"] = df[\"CE_steps\"].apply(lambda x: sum(x) / len(x))" ] }, { @@ -162,16 +164,33 @@ ], "source": [ "from fiora.visualization.define_colors import *\n", + "\n", "set_light_theme()\n", - "fig, axs = plt.subplots(1, 2, figsize=(15, 5), gridspec_kw={'width_ratios': [1, 1]}, sharey=True)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(15, 5), gridspec_kw={\"width_ratios\": [1, 1]}, sharey=True\n", + ")\n", "\n", "\n", - "sns.histplot(ax=axs[0], data=df[df[\"SPECTYPE\"].isin([\"SINGLE_BEST_SCAN\", \"SAME_ENERGY\"])], x=\"CE\", hue=\"origin\", multiple=\"stack\", binwidth=2)\n", - "sns.histplot(ax=axs[1], data=df[df[\"SPECTYPE\"].isin([\"ALL_ENERGIES\", 
\"ALL_MSN_TO_PSEUDO_MS2\"])], x=\"CE\", hue=\"origin\", multiple=\"stack\", binwidth=2)\n", + "sns.histplot(\n", + " ax=axs[0],\n", + " data=df[df[\"SPECTYPE\"].isin([\"SINGLE_BEST_SCAN\", \"SAME_ENERGY\"])],\n", + " x=\"CE\",\n", + " hue=\"origin\",\n", + " multiple=\"stack\",\n", + " binwidth=2,\n", + ")\n", + "sns.histplot(\n", + " ax=axs[1],\n", + " data=df[df[\"SPECTYPE\"].isin([\"ALL_ENERGIES\", \"ALL_MSN_TO_PSEUDO_MS2\"])],\n", + " x=\"CE\",\n", + " hue=\"origin\",\n", + " multiple=\"stack\",\n", + " binwidth=2,\n", + ")\n", "axs[0].set_title(\"Single Energy\")\n", "axs[0].legend(\"\")\n", "axs[1].set_title(\"Multiple Energies (Average)\")\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -191,15 +210,33 @@ } ], "source": [ - "fig, axs = plt.subplots(1, 2, figsize=(10, 5), gridspec_kw={'width_ratios': [1.5, 1]}, sharey=True)\n", - "#plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", - "\n", - "sns.histplot(ax=axs[0], data=df[df[\"IONMODE\"] == \"Positive\"], x=\"ADDUCT\", palette=magma(7), hue=\"ADDUCT\", edgecolor=\"black\", stat=\"density\")\n", - "axs[0].tick_params(axis='x', rotation=60)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(10, 5), gridspec_kw={\"width_ratios\": [1.5, 1]}, sharey=True\n", + ")\n", + "# plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", + "\n", + "sns.histplot(\n", + " ax=axs[0],\n", + " data=df[df[\"IONMODE\"] == \"Positive\"],\n", + " x=\"ADDUCT\",\n", + " palette=magma(7),\n", + " hue=\"ADDUCT\",\n", + " edgecolor=\"black\",\n", + " stat=\"density\",\n", + ")\n", + "axs[0].tick_params(axis=\"x\", rotation=60)\n", "axs[0].set_xlabel(\"\")\n", - "sns.histplot(ax=axs[1], data=df[df[\"IONMODE\"] == \"Negative\"], x=\"ADDUCT\", palette=sns.color_palette(\"mako_r\", 4), hue=\"ADDUCT\", edgecolor=\"black\", stat=\"density\")\n", + "sns.histplot(\n", + " ax=axs[1],\n", + " data=df[df[\"IONMODE\"] == \"Negative\"],\n", + " x=\"ADDUCT\",\n", + " palette=sns.color_palette(\"mako_r\", 4),\n", + " hue=\"ADDUCT\",\n", + " edgecolor=\"black\",\n", + " stat=\"density\",\n", + ")\n", "axs[1].set_xlabel(\"\")\n", - "axs[1].tick_params(axis='x', rotation=60)\n", + "axs[1].tick_params(axis=\"x\", rotation=60)\n", "plt.show()" ] }, @@ -230,7 +267,7 @@ "if filter_spectype:\n", " df = df[df[\"SPECTYPE\"].isin(keep_spectypes)]\n", "\n", - "# Note that this early filter step speeds up subsequent operations, \n", + "# Note that this early filter step speeds up subsequent operations,\n", "# but one may consider to include stepped/merged spectra and specifically model CE steps." 
] }, @@ -258,7 +295,9 @@ "df[\"ppm_peak_tolerance\"] = TOLERANCE\n", "df[\"Metabolite\"] = df[\"SMILES\"].apply(Metabolite)\n", "df[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", - "_ = df[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(memory_safe=False)) # Set memory_safe=False if necessary" + "_ = df[\"Metabolite\"].apply(\n", + " lambda x: x.compute_graph_attributes(memory_safe=False)\n", + ") # Set memory_safe=False if necessary" ] }, { @@ -279,13 +318,22 @@ "metadata": {}, "outputs": [], "source": [ - "h_plus = Chem.MolFromSmiles(\"[H+]\") #h proton\n", - "\n", - "constants.ADDUCT_WEIGHTS.update({\n", - " \"[M+2H]-\": Descriptors.ExactMolWt(h_plus) + 1 * Descriptors.ExactMolWt(Chem.MolFromSmiles(\"[H]\")), # 1 proton + 2 neutral hydrogens\n", - " \"[M+3H]-\": Descriptors.ExactMolWt(h_plus) + 2 * Descriptors.ExactMolWt(Chem.MolFromSmiles(\"[H]\")), # 1 proton + 2 neutral hydrogens\n", - " \n", - "})\n", + "h_plus = Chem.MolFromSmiles(\"[H+]\") # h proton\n", + "\n", + "constants.ADDUCT_WEIGHTS.update(\n", + " {\n", + " \"[M+2H]-\": Descriptors.ExactMolWt(h_plus)\n", + " + 1\n", + " * Descriptors.ExactMolWt(\n", + " Chem.MolFromSmiles(\"[H]\")\n", + " ), # 1 proton + 2 neutral hydrogens\n", + " \"[M+3H]-\": Descriptors.ExactMolWt(h_plus)\n", + " + 2\n", + " * Descriptors.ExactMolWt(\n", + " Chem.MolFromSmiles(\"[H]\")\n", + " ), # 1 proton + 2 neutral hydrogens\n", + " }\n", + ")\n", "mindex.create_fragmentation_trees()" ] }, @@ -303,7 +351,9 @@ } ], "source": [ - "list_of_mismatched_ids = mindex.add_fragmentation_trees_to_metabolite_list(df[\"Metabolite\"], graph_mismatch_policy=\"recompute\")\n", + "list_of_mismatched_ids = mindex.add_fragmentation_trees_to_metabolite_list(\n", + " df[\"Metabolite\"], graph_mismatch_policy=\"recompute\"\n", + ")\n", "print(f\"Total number of recomputed trees: {len(list_of_mismatched_ids)}\")" ] }, @@ -329,11 +379,14 @@ "df[\"loss_weight\"] = df[\"Metabolite\"].apply(lambda x: x.loss_weight)\n", "print(f\"Number of metabolites in index: {mindex.get_number_of_metabolites()}\")\n", "\n", + "\n", "def print_df_stats(df):\n", " num_spectra = df.shape[0]\n", " num_ids = len(df[\"group_id\"].unique())\n", - " \n", - " print(f\"Dataframe stats: {num_spectra} spectra covering {num_ids} unique structures\")\n", + "\n", + " print(\n", + " f\"Dataframe stats: {num_spectra} spectra covering {num_ids} unique structures\"\n", + " )\n", "\n", "\n", "print_df_stats(df)" @@ -345,7 +398,15 @@ "metadata": {}, "outputs": [], "source": [ - "_ = df.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=TOLERANCE, match_stats_only=True), axis=1)" + "_ = df.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"],\n", + " x[\"peaks\"][\"intensity\"],\n", + " tolerance=TOLERANCE,\n", + " match_stats_only=True,\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -366,10 +427,19 @@ ], "source": [ "df[\"coverage\"] = df[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", - "sns.histplot(data=df[df[\"SPECTYPE\"].isin([\"SINGLE_BEST_SCAN\", \"SAME_ENERGY\"])], x=\"coverage\", hue=\"CE\", palette=\"magma_r\", edgecolor=\"black\", multiple=\"stack\", stat=\"density\", hue_norm=(df[\"CE\"].min(), df[\"CE\"].max()))\n", + "sns.histplot(\n", + " data=df[df[\"SPECTYPE\"].isin([\"SINGLE_BEST_SCAN\", \"SAME_ENERGY\"])],\n", + " x=\"coverage\",\n", + " hue=\"CE\",\n", + " palette=\"magma_r\",\n", + " edgecolor=\"black\",\n", + " 
multiple=\"stack\",\n", + " stat=\"density\",\n", + " hue_norm=(df[\"CE\"].min(), df[\"CE\"].max()),\n", + ")\n", "plt.xlabel(\"Peak intensity covered by single fragmentation events\")\n", "\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -390,9 +460,18 @@ ], "source": [ "df[\"Num peaks\"] = df[\"Num peaks\"].astype(int)\n", - "sns.histplot(df, x=\"Num peaks\", color=magma(6)[2], multiple=\"stack\", stat=\"density\", binwidth=5, edgecolor=\"white\", linewidth=0.5)\n", + "sns.histplot(\n", + " df,\n", + " x=\"Num peaks\",\n", + " color=magma(6)[2],\n", + " multiple=\"stack\",\n", + " stat=\"density\",\n", + " binwidth=5,\n", + " edgecolor=\"white\",\n", + " linewidth=0.5,\n", + ")\n", "plt.xlabel(\"Num of peaks\")\n", - "plt.xlim(0,250)\n", + "plt.xlim(0, 250)\n", "plt.show()" ] }, @@ -443,15 +522,15 @@ "from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder\n", "\n", "metadata_key_map = {\n", - " \"name\": \"NAME\",\n", - " \"collision_energy\": \"CE\", \n", - " \"instrument\": \"instrument\",\n", - " \"ionization\": \"ionization\",\n", - " \"precursor_mz\": \"PEPMASS\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RTINSECONDS\",\n", - " \"ce_steps\": \"CE_steps\"\n", - " }\n", + " \"name\": \"NAME\",\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"instrument\",\n", + " \"ionization\": \"ionization\",\n", + " \"precursor_mz\": \"PEPMASS\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"RTINSECONDS\",\n", + " \"ce_steps\": \"CE_steps\",\n", + "}\n", "\n", "filter_spectra = True\n", "CE_upper_limit = 100.0\n", @@ -461,27 +540,48 @@ "\n", "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", - "setup_encoder = CovariateFeatureEncoder(feature_list=[\"collision_energy\", \"molecular_weight\", \"precursor_mode\", \"instrument\"])\n", - "rt_encoder = CovariateFeatureEncoder(feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"])\n", + "setup_encoder = CovariateFeatureEncoder(\n", + " feature_list=[\n", + " \"collision_energy\",\n", + " \"molecular_weight\",\n", + " \"precursor_mode\",\n", + " \"instrument\",\n", + " ]\n", + ")\n", + "rt_encoder = CovariateFeatureEncoder(\n", + " feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"]\n", + ")\n", "\n", "if filter_spectra:\n", - " setup_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit \n", - " setup_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n", - " rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n", + " setup_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", + " setup_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", + " rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", "\n", - "df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", - "df.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + "df[\"summary\"] = df.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n", + ")\n", + "df.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + ")\n", "\n", "if filter_spectra:\n", - " df = 
df[df['ADDUCT'].isin(allowed_precursor_modes)]\n", + " df = df[df[\"ADDUCT\"].isin(allowed_precursor_modes)]\n", " num_ori = df.shape[0]\n", - " correct_energy = df[\"Metabolite\"].apply(lambda x: x.metadata[\"collision_energy\"] <= CE_upper_limit and x.metadata[\"collision_energy\"] > 1) \n", + " correct_energy = df[\"Metabolite\"].apply(\n", + " lambda x: (\n", + " x.metadata[\"collision_energy\"] <= CE_upper_limit\n", + " and x.metadata[\"collision_energy\"] > 1\n", + " )\n", + " )\n", " df = df[correct_energy]\n", - " correct_weight = df[\"Metabolite\"].apply(lambda x: x.metadata[\"molecular_weight\"] <= weight_upper_limit)\n", - " df = df[correct_weight] \n", + " correct_weight = df[\"Metabolite\"].apply(\n", + " lambda x: x.metadata[\"molecular_weight\"] <= weight_upper_limit\n", + " )\n", + " df = df[correct_weight]\n", " print(f\"Filtering spectra ({num_ori}) down to {df.shape[0]}\")\n", - " #df = df[df[\"SPECTYPE\"] != \"ALL_ENERGIES\"]\n", - " #print(df[\"Precursor_type\"].value_counts())\n" + " # df = df[df[\"SPECTYPE\"] != \"ALL_ENERGIES\"]\n", + " # print(df[\"Precursor_type\"].value_counts())" ] }, { @@ -535,7 +635,10 @@ "print(\"Num values in s_single_best:\", len(s_single_best))\n", "print(\"Num values in s_single:\", len(s_single))\n", "print(\"Num values in s_same:\", len(s_same))\n", - "print(\"Num values in combined single and same:\", len(s_single_best.union(s_same).union(s_single)))\n", + "print(\n", + " \"Num values in combined single and same:\",\n", + " len(s_single_best.union(s_same).union(s_single)),\n", + ")\n", "print(\"Num values in s_all:\", len(s_all))\n", "print(\"Num values in s_pseudo:\", len(s_pseudo))" ] @@ -596,7 +699,14 @@ "\n", "graph = df[\"Metabolite\"].iloc[0].Graph\n", "plt.figure(figsize=(10, 10))\n", - "nx.draw(graph, with_labels=True, node_color='lightblue', edge_color='gray', node_size=500, font_size=10)\n", + "nx.draw(\n", + " graph,\n", + " with_labels=True,\n", + " node_color=\"lightblue\",\n", + " edge_color=\"gray\",\n", + " node_size=500,\n", + " font_size=10,\n", + ")\n", "plt.show()" ] }, @@ -733,7 +843,7 @@ } ], "source": [ - "raise KeyboardInterrupt() #TODO continue here" + "raise KeyboardInterrupt() # TODO continue here" ] }, { @@ -757,7 +867,9 @@ "outputs": [], "source": [ "# TODO run this\n", - "ring_condition = df[\"Metabolite\"].apply(lambda x: (x.match_stats[\"coverage\"] < 0.5) & bool(x.ring_proportion > 0.8))\n", + "ring_condition = df[\"Metabolite\"].apply(\n", + " lambda x: (x.match_stats[\"coverage\"] < 0.5) & bool(x.ring_proportion > 0.8)\n", + ")\n", "df_rings = df[ring_condition]\n", "\n", "# TODO continue" @@ -793,7 +905,7 @@ "source": [ "save_rings: bool = False\n", "if save_rings:\n", - " path: str = f'{home}/data/metabolites/preprocessed/rings_msnlib.csv'\n", + " path: str = f\"{home}/data/metabolites/preprocessed/rings_msnlib.csv\"\n", " print(f\"Saving ring dataframe to {path}\")\n", " df_rings.to_csv(path)" ] @@ -816,21 +928,21 @@ "from typing import Dict\n", "\n", "# Hard filter conditions that must be fulfilled\n", - "hard_filters : Dict[str, int] = {\n", + "hard_filters: Dict[str, int] = {\n", " \"min_peaks\": 2,\n", " \"min_coverage\": 0.5,\n", " \"max_precursor_intensity\": 0.9,\n", "}\n", "\n", "# Soft conditions where at least one must be met\n", - "soft_filters : Dict[str, int] = {\n", + "soft_filters: Dict[str, int] = {\n", " \"desired_peaks\": 4,\n", " \"desired_coverage\": 0.75,\n", - " \"desired_peak_percentage\": 0.5, # Proportion of peaks covered by the fragmentation\n", + " 
\"desired_peak_percentage\": 0.5, # Proportion of peaks covered by the fragmentation\n", "}\n", "\n", - "hard_filters_drops : Dict[str, list] = {key: [] for key in hard_filters.keys()}\n", - "soft_filters_keeps : Dict[str, list] = {key: [] for key in soft_filters.keys()}" + "hard_filters_drops: Dict[str, list] = {key: [] for key in hard_filters.keys()}\n", + "soft_filters_keeps: Dict[str, list] = {key: [] for key in soft_filters.keys()}" ] }, { @@ -906,7 +1018,7 @@ "key_to_label = {\n", " \"min_peaks\": \"Min Peaks\",\n", " \"min_coverage\": \"Min Coverage\",\n", - " \"max_precursor_intensity\": \"Max Precursor Intensity\"\n", + " \"max_precursor_intensity\": \"Max Precursor Intensity\",\n", "}\n", "labels = [key_to_label[key] for key in hard_filters_drops.keys()]\n", "\n", @@ -918,8 +1030,7 @@ "for subset in venn.patches:\n", " if subset: # Check if the subset exists (not None)\n", " subset.set_linewidth(2) # Set line thickness\n", - " subset.set_edgecolor('black') # Set edge color to black\n", - "\n", + " subset.set_edgecolor(\"black\") # Set edge color to black\n", "\n", "\n", "# Add a title\n", @@ -992,7 +1103,7 @@ "import matplotlib.pyplot as plt\n", "\n", "# Convert spectral indices to group IDs\n", - "group_ids = df['group_id'].unique()\n", + "group_ids = df[\"group_id\"].unique()\n", "\n", "# Identify compounds completely removed by hard filters\n", "hard_filter_indices = set()\n", @@ -1000,15 +1111,20 @@ " hard_filter_indices.update(indices)\n", "\n", "compounds_removed_by_hard_filters = {\n", - " group_id for group_id in group_ids\n", - " if all(idx in hard_filter_indices for idx in df[df['group_id'] == group_id].index)\n", + " group_id\n", + " for group_id in group_ids\n", + " if all(idx in hard_filter_indices for idx in df[df[\"group_id\"] == group_id].index)\n", "}\n", "\n", "# Identify compounds removed by soft filters (or hard filters)\n", "soft_filter_indices = set(drop_indices) - hard_filter_indices\n", "compounds_removed_by_soft_filters = {\n", - " group_id for group_id in group_ids\n", - " if all(idx in hard_filter_indices or idx in soft_filter_indices for idx in df[df['group_id'] == group_id].index)\n", + " group_id\n", + " for group_id in group_ids\n", + " if all(\n", + " idx in hard_filter_indices or idx in soft_filter_indices\n", + " for idx in df[df[\"group_id\"] == group_id].index\n", + " )\n", "}\n", "\n", "# All group IDs\n", @@ -1017,15 +1133,19 @@ "# Create the Venn diagram\n", "plt.figure(figsize=(8, 8))\n", "venn = venn3(\n", - " [all_compounds, compounds_removed_by_hard_filters, compounds_removed_by_soft_filters],\n", - " ('All Compounds', 'Removed By Hard Filter', 'Removed By Soft Filters')\n", + " [\n", + " all_compounds,\n", + " compounds_removed_by_hard_filters,\n", + " compounds_removed_by_soft_filters,\n", + " ],\n", + " (\"All Compounds\", \"Removed By Hard Filter\", \"Removed By Soft Filters\"),\n", ")\n", "\n", "# Customize the Venn diagram with thicker lines and black edges\n", "for subset in venn.patches:\n", " if subset: # Check if the subset exists (not None)\n", " subset.set_linewidth(2) # Set line thickness\n", - " subset.set_edgecolor('black') # Set edge color to black\n", + " subset.set_edgecolor(\"black\") # Set edge color to black\n", "\n", "# Add a title\n", "plt.title(\"Overlap of Compounds Removed by Hard and Soft Filters\")\n", @@ -1086,7 +1206,16 @@ } ], "source": [ - "sns.histplot(df[df[\"SPECTYPE\"].isin([\"SAME_ENERGY\", \"SINGLE_BEST_SCAN\"], )], x= \"coverage\", hue=\"CE\", multiple=\"stack\")" + "sns.histplot(\n", + " df[\n", + 
" df[\"SPECTYPE\"].isin(\n", + " [\"SAME_ENERGY\", \"SINGLE_BEST_SCAN\"],\n", + " )\n", + " ],\n", + " x=\"coverage\",\n", + " hue=\"CE\",\n", + " multiple=\"stack\",\n", + ")" ] }, { @@ -1120,19 +1249,23 @@ "from collections import defaultdict\n", "from fiora.MOL.constants import DEFAULT_MODES\n", "\n", + "\n", "def count_ion_mode_matches(df):\n", " ion_mode_counts = defaultdict(float)\n", "\n", " for metabolite in df[\"Metabolite\"]:\n", " for peak_data in metabolite.peak_matches.values():\n", - " ion_modes = peak_data['ion_modes']\n", + " ion_modes = peak_data[\"ion_modes\"]\n", " num_modes = len(ion_modes)\n", " for mode, _ in ion_modes: # mode is the ion mode string\n", " if mode in DEFAULT_MODES:\n", - " ion_mode_counts[mode] += 1 / num_modes # Divide count by number of modes\n", + " ion_mode_counts[mode] += (\n", + " 1 / num_modes\n", + " ) # Divide count by number of modes\n", "\n", " return dict(ion_mode_counts)\n", "\n", + "\n", "ion_mode_counts = count_ion_mode_matches(df)\n", "ion_mode_df = pd.DataFrame(list(ion_mode_counts.items()), columns=[\"Ion Mode\", \"Count\"])\n", "\n", @@ -1143,13 +1276,13 @@ "axes[0].set_title(\"Ion Mode Counts (Bar Plot)\")\n", "axes[0].set_xlabel(\"Ion Mode\")\n", "axes[0].set_ylabel(\"Count\")\n", - "axes[0].tick_params(axis='x', rotation=45)\n", + "axes[0].tick_params(axis=\"x\", rotation=45)\n", "axes[1].pie(\n", " ion_mode_df[\"Count\"],\n", " labels=ion_mode_df[\"Ion Mode\"],\n", - " autopct='%1.1f%%',\n", + " autopct=\"%1.1f%%\",\n", " startangle=90,\n", - " colors=sns.color_palette(\"viridis\", len(ion_mode_df))\n", + " colors=sns.color_palette(\"viridis\", len(ion_mode_df)),\n", ")\n", "axes[1].set_title(\"Ion Mode Distribution (Pie Chart)\")\n", "\n", @@ -1181,15 +1314,34 @@ "L = LibraryLoader()\n", "casmi16_path = f\"{home}/data/metabolites/CASMI_2016/casmi16_withCCS.csv\"\n", "casmi22_path = f\"{home}/data/metabolites/CASMI_2022/casmi22_withCCS.csv\"\n", - "df_merged = L.load_from_csv(f\"{home}/data/metabolites/preprocessed/datasplits_Jan24.csv\")\n", + "df_merged = L.load_from_csv(\n", + " f\"{home}/data/metabolites/preprocessed/datasplits_Jan24.csv\"\n", + ")\n", "df_cas = pd.read_csv(casmi16_path, index_col=[0], low_memory=False)\n", "df_cas22 = pd.read_csv(casmi22_path, index_col=[0], low_memory=False)\n", - "df_cast = pd.read_csv(f\"{home}/data/metabolites/CASMI_2016/casmi16t_withCCS.csv\", index_col=[0], low_memory=False)\n", + "df_cast = pd.read_csv(\n", + " f\"{home}/data/metabolites/CASMI_2016/casmi16t_withCCS.csv\",\n", + " index_col=[0],\n", + " low_memory=False,\n", + ")\n", "\n", "other_dfs = {\n", - " \"train\": df_merged[df_merged[\"dataset\"] == \"training\"].drop_duplicates(subset=[\"group_id\"]),\n", - " \"val\": df_merged[df_merged[\"dataset\"] == \"validation\"].drop_duplicates(subset=[\"group_id\"]),\n", - " \"test\": pd.concat([df_merged[df_merged[\"dataset\"] == \"test\"].drop_duplicates(subset=[\"group_id\"]), df_cas, df_cast, df_cas22]).drop_duplicates(subset=[\"SMILES\"]),\n", + " \"train\": df_merged[df_merged[\"dataset\"] == \"training\"].drop_duplicates(\n", + " subset=[\"group_id\"]\n", + " ),\n", + " \"val\": df_merged[df_merged[\"dataset\"] == \"validation\"].drop_duplicates(\n", + " subset=[\"group_id\"]\n", + " ),\n", + " \"test\": pd.concat(\n", + " [\n", + " df_merged[df_merged[\"dataset\"] == \"test\"].drop_duplicates(\n", + " subset=[\"group_id\"]\n", + " ),\n", + " df_cas,\n", + " df_cast,\n", + " df_cas22,\n", + " ]\n", + " ).drop_duplicates(subset=[\"SMILES\"]),\n", "}" ] }, @@ -1199,17 +1351,17 @@ 
"metadata": {}, "outputs": [], "source": [ - "lookup_table = {\n", - " \"train\": set(),\n", - " \"val\": set(),\n", - " \"test\": set()\n", - "}\n", + "lookup_table = {\"train\": set(), \"val\": set(), \"test\": set()}\n", "for key, df_x in other_dfs.items():\n", - " df_x[\"Metabolite\"] = df_x[\"SMILES\"].apply(Metabolite)\n", - " \n", - " for i, data in df_x.iterrows():\n", - " lookup_table[key].add((data[\"Metabolite\"].ExactMolWeight, data[\"Metabolite\"].morganFingerCountOnes))\n", - " " + " df_x[\"Metabolite\"] = df_x[\"SMILES\"].apply(Metabolite)\n", + "\n", + " for i, data in df_x.iterrows():\n", + " lookup_table[key].add(\n", + " (\n", + " data[\"Metabolite\"].ExactMolWeight,\n", + " data[\"Metabolite\"].morganFingerCountOnes,\n", + " )\n", + " )" ] }, { @@ -1225,7 +1377,7 @@ " metabolite: Metabolite = df[df[\"group_id\"] == id].iloc[0][\"Metabolite\"]\n", " fast_identifiers = (metabolite.ExactMolWeight, metabolite.morganFingerCountOnes)\n", " found_match = False\n", - " \n", + "\n", " if fast_identifiers in lookup_table[\"train\"]:\n", " for i, data in other_dfs[\"train\"].iterrows():\n", " other_metabolite = data[\"Metabolite\"]\n", @@ -1240,11 +1392,11 @@ " val.append(id)\n", " break\n", " if not found_match and fast_identifiers in lookup_table[\"test\"]:\n", - " for i, data in other_dfs[\"test\"].iterrows():\n", - " other_metabolite = data[\"Metabolite\"]\n", - " if metabolite == other_metabolite:\n", - " test.append(id)\n", - " break" + " for i, data in other_dfs[\"test\"].iterrows():\n", + " other_metabolite = data[\"Metabolite\"]\n", + " if metabolite == other_metabolite:\n", + " test.append(id)\n", + " break" ] }, { @@ -1262,7 +1414,9 @@ } ], "source": [ - "print(f\"Preset compounds assigned to datasplits: {len(train)=} {len(val)=} {len(test)=}\")\n", + "print(\n", + " f\"Preset compounds assigned to datasplits: {len(train)=} {len(val)=} {len(test)=}\"\n", + ")\n", "print(f\"{train[:5]=} {val[-5:]=} {test[5:10]=}\")" ] }, @@ -1274,12 +1428,17 @@ "source": [ "from sklearn.model_selection import train_test_split\n", "\n", + "\n", "def train_val_test_split(keys, test_size=0.1, val_size=0.1, rseed=42):\n", - " temp_keys, test_keys = train_test_split(keys, test_size=test_size, random_state=rseed)\n", + " temp_keys, test_keys = train_test_split(\n", + " keys, test_size=test_size, random_state=rseed\n", + " )\n", " adjusted_val_size = val_size / (1 - test_size)\n", - " train_keys, val_keys = train_test_split(temp_keys, test_size=adjusted_val_size, random_state=rseed)\n", - " \n", - " return train_keys, val_keys, test_keys\n" + " train_keys, val_keys = train_test_split(\n", + " temp_keys, test_size=adjusted_val_size, random_state=rseed\n", + " )\n", + "\n", + " return train_keys, val_keys, test_keys" ] }, { @@ -1301,13 +1460,25 @@ "\n", "test_new_frac = test_size_remaining / num_unassigned\n", "val_new_frac = val_size_remaining / num_unassigned\n", - "train_add, val_add, test_add = train_val_test_split(unassigned_keys, test_size=test_new_frac, val_size=val_new_frac)\n", + "train_add, val_add, test_add = train_val_test_split(\n", + " unassigned_keys, test_size=test_new_frac, val_size=val_new_frac\n", + ")\n", "train = np.concatenate((np.array(train), train_add))\n", "val = np.concatenate((np.array(val), val_add))\n", "test = np.concatenate((np.array(test), test_add))\n", "\n", "\n", - "df[\"dataset\"] = df[\"group_id\"].apply(lambda x: 'training' if x in train else 'validation' if x in val else 'test' if x in test else 'VALUE ERROR')\n", + "df[\"dataset\"] = 
df[\"group_id\"].apply(\n", + " lambda x: (\n", + " \"training\"\n", + " if x in train\n", + " else \"validation\"\n", + " if x in val\n", + " else \"test\"\n", + " if x in test\n", + " else \"VALUE ERROR\"\n", + " )\n", + ")\n", "df[\"datasplit\"] = df[\"dataset\"]" ] }, @@ -1389,7 +1560,7 @@ "source": [ "save_df: bool = False\n", "if save_df:\n", - " path: str = f'{home}/data/metabolites/preprocessed/datasplits_msnlib_v7_Sep25.csv' # Save with merge spectra (4)\n", + " path: str = f\"{home}/data/metabolites/preprocessed/datasplits_msnlib_v7_Sep25.csv\" # Save with merge spectra (4)\n", " print(f\"Saving datasplits to {path}\")\n", " df.to_csv(path)" ] @@ -1424,15 +1595,23 @@ "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), sharey=False)\n", "\n", "fig.suptitle(f\"Identified peak-fragment matches and number conflicts\")\n", - "#plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n", - "sns.histplot(ax=axs[0],data=df, x=\"num_peak_matches\", color=color_palette[0], edgecolor=\"black\", bins=range(0,20, 1))\n", - "#axs[0].set_ylim(-0.5, 10)\n", + "# plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n", + "sns.histplot(\n", + " ax=axs[0],\n", + " data=df,\n", + " x=\"num_peak_matches\",\n", + " color=color_palette[0],\n", + " edgecolor=\"black\",\n", + " bins=range(0, 20, 1),\n", + ")\n", + "# axs[0].set_ylim(-0.5, 10)\n", "axs[0].set_ylabel(\"peaks identified\")\n", "\n", "\n", - "\n", - "sns.histplot(ax=axs[1],data=df, x=\"num_fragment_conflicts\", color=color_palette[3], binwidth=1)\n", - "#axs[2].set_ylim(-0.5, 1000)\n", + "sns.histplot(\n", + " ax=axs[1], data=df, x=\"num_fragment_conflicts\", color=color_palette[3], binwidth=1\n", + ")\n", + "# axs[2].set_ylim(-0.5, 1000)\n", "axs[1].set_xlabel(\"conflicts\")\n", "axs[1].set_ylabel(\"\")\n", "\n", @@ -1456,7 +1635,7 @@ } ], "source": [ - "#df[\"RETENTIONTIME\"] = df[\"RTINSECONDS\"].astype(float)\n", + "# df[\"RETENTIONTIME\"] = df[\"RTINSECONDS\"].astype(float)\n", "sns.displot(df, x=\"RTINSECONDS\", kde=True, binwidth=0.5, hue=\"ADDUCT\")\n", "plt.show()" ] @@ -1478,10 +1657,17 @@ } ], "source": [ - "\n", "fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4))\n", "\n", - "sns.histplot(ax=ax, data=df, x=\"RTINSECONDS\", hue='origin', multiple=\"stack\", binwidth=1, stat=\"probability\")\n", + "sns.histplot(\n", + " ax=ax,\n", + " data=df,\n", + " x=\"RTINSECONDS\",\n", + " hue=\"origin\",\n", + " multiple=\"stack\",\n", + " binwidth=1,\n", + " stat=\"probability\",\n", + ")\n", "plt.show()" ] }, @@ -1502,14 +1688,13 @@ } ], "source": [ - "\n", "fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4))\n", "\n", - "sns.kdeplot(ax=ax, data=df, x=\"RTINSECONDS\", hue='origin', multiple=\"fill\", common_norm=False)\n", + "sns.kdeplot(\n", + " ax=ax, data=df, x=\"RTINSECONDS\", hue=\"origin\", multiple=\"fill\", common_norm=False\n", + ")\n", "ax.legend(bbox_to_anchor=(1.5, 0.8), labels=df[\"origin\"].unique())\n", - "plt.show()\n", - " \n", - " " + "plt.show()" ] } ], diff --git a/lib_loader/nist_library_loader.ipynb b/lib_loader/nist_library_loader.ipynb index 893ad0f..8c66cd2 100644 --- a/lib_loader/nist_library_loader.ipynb +++ b/lib_loader/nist_library_loader.ipynb @@ -31,12 +31,14 @@ ], "source": [ "import sys\n", - "print(f'Working with Python {sys.version}')\n", + "\n", + "print(f\"Working with Python {sys.version}\")\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import importlib\n", - "#import swifter\n", + "\n", + "# import swifter\n", "import matplotlib.pyplot as plt\n", "import seaborn as 
sns\n", "import collections\n", @@ -52,11 +54,14 @@ "\n", "# Deep Learning\n", "import sklearn\n", - "#import spektral\n", + "\n", + "# import spektral\n", "from sklearn.model_selection import train_test_split\n", + "\n", "# Keras\n", "from sklearn.model_selection import train_test_split\n", - "#import stellargraph as sg\n", + "\n", + "# import stellargraph as sg\n", "from rdkit import RDLogger\n", "\n", "\n", @@ -64,16 +69,17 @@ "# Load Modules\n", "sys.path.append(\"..\")\n", "from os.path import expanduser\n", + "\n", "home = expanduser(\"~\")\n", "import fiora.IO.mspReader as mspReader\n", "import fiora.visualization.spectrum_visualizer as sv\n", "import fiora.IO.molReader as molReader\n", "\n", "\n", - "RDLogger.DisableLog('rdApp.*')\n", + "RDLogger.DisableLog(\"rdApp.*\")\n", "\n", "\n", - "caffeine_smiles = 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C'\n", + "caffeine_smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\"\n", "caffeine_mol = Chem.MolFromSmiles(caffeine_smiles)\n", "\n", "caffeine_mol" @@ -122,12 +128,14 @@ } ], "source": [ - "nist_msp = mspReader.read(library_directory + library_name + '.MSP')\n", + "nist_msp = mspReader.read(library_directory + library_name + \".MSP\")\n", "df_nist = pd.DataFrame(nist_msp)\n", "\n", - "#df_nist['mol'] = df_nist['SMILES'].apply(Chem.MolFromSmiles)\n", - "#df_nist.dropna(inplace=True)\n", - "print(f\"Spectral file loaded with {df_nist.shape[0]} entries and {df_nist.shape[1]} variables\")\n" + "# df_nist['mol'] = df_nist['SMILES'].apply(Chem.MolFromSmiles)\n", + "# df_nist.dropna(inplace=True)\n", + "print(\n", + " f\"Spectral file loaded with {df_nist.shape[0]} entries and {df_nist.shape[1]} variables\"\n", + ")" ] }, { @@ -167,10 +175,19 @@ "source": [ "# Define figure styles\n", "color_palette = sns.color_palette(\"magma_r\", 8)\n", - "sns.set_theme(style=\"whitegrid\",\n", - " rc={'axes.edgecolor': 'black', 'ytick.left': True, 'xtick.bottom': True, 'xtick.color': 'black',\n", - " \"axes.spines.bottom\": True, \"axes.spines.right\": True, \"axes.spines.top\": True,\n", - " \"axes.spines.left\": True})\n" + "sns.set_theme(\n", + " style=\"whitegrid\",\n", + " rc={\n", + " \"axes.edgecolor\": \"black\",\n", + " \"ytick.left\": True,\n", + " \"xtick.bottom\": True,\n", + " \"xtick.color\": \"black\",\n", + " \"axes.spines.bottom\": True,\n", + " \"axes.spines.right\": True,\n", + " \"axes.spines.top\": True,\n", + " \"axes.spines.left\": True,\n", + " },\n", + ")" ] }, { @@ -179,19 +196,28 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "\n", - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 2]}, sharey=True)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 2]}, sharey=True\n", + ")\n", "fig.set_tight_layout(False)\n", "for ax in axs:\n", - " ax.tick_params('x', labelrotation=45)\n", - "\n", - "sns.countplot(ax=axs[0], data=df_nist, x='Spectrum_type', edgecolor=\"black\", palette=color_palette)\n", - "sns.countplot(ax=axs[1], data=df_nist, x='Precursor_type', edgecolor=\"black\", palette=color_palette, order=df_nist['Precursor_type'].value_counts().iloc[:8].index)\n", + " ax.tick_params(\"x\", labelrotation=45)\n", + "\n", + "sns.countplot(\n", + " ax=axs[0], data=df_nist, x=\"Spectrum_type\", edgecolor=\"black\", palette=color_palette\n", + ")\n", + "sns.countplot(\n", + " ax=axs[1],\n", + " data=df_nist,\n", + " x=\"Precursor_type\",\n", + " edgecolor=\"black\",\n", + " palette=color_palette,\n", + " 
order=df_nist[\"Precursor_type\"].value_counts().iloc[:8].index,\n", + ")\n", "axs[0].set_ylim(0, 500000)\n", "axs[1].set_ylabel(\"\")\n", "\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -203,11 +229,13 @@ "# Filters\n", "df_nist = df_nist[df_nist[\"Spectrum_type\"] == \"MS2\"]\n", "target_precursor_type = [\"[M+H]+\", \"[M-H]-\", \"[M+H-H2O]+\", \"[M+Na]+\"]\n", - "df_nist = df_nist[df_nist[\"Precursor_type\"].apply(lambda ptype: ptype in target_precursor_type)]\n", + "df_nist = df_nist[\n", + " df_nist[\"Precursor_type\"].apply(lambda ptype: ptype in target_precursor_type)\n", + "]\n", "\n", "# Formats\n", - "df_nist['PrecursorMZ'] = df_nist[\"PrecursorMZ\"].astype('float')\n", - "df_nist['Num peaks'] = df_nist[\"Num peaks\"].astype('int')\n", + "df_nist[\"PrecursorMZ\"] = df_nist[\"PrecursorMZ\"].astype(\"float\")\n", + "df_nist[\"Num peaks\"] = df_nist[\"Num peaks\"].astype(\"int\")\n", "\n", "\n", "print(f\"Spectral file filtered down to {df_nist.shape[0]} entries\")" @@ -219,14 +247,25 @@ "metadata": {}, "outputs": [], "source": [ - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 2]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 2]}, sharey=False\n", + ")\n", "for ax in axs:\n", - " ax.tick_params('x', labelrotation=45)\n", - "\n", - "sns.boxplot(ax=axs[0], data=df_nist, y='PrecursorMZ', palette=color_palette, x=\"Precursor_type\")\n", - "sns.histplot(ax=axs[1], data=df_nist, x='Num peaks', color=color_palette[7], fill=True, edgecolor=\"black\")#, order=list(range(0,200)))\n", + " ax.tick_params(\"x\", labelrotation=45)\n", + "\n", + "sns.boxplot(\n", + " ax=axs[0], data=df_nist, y=\"PrecursorMZ\", palette=color_palette, x=\"Precursor_type\"\n", + ")\n", + "sns.histplot(\n", + " ax=axs[1],\n", + " data=df_nist,\n", + " x=\"Num peaks\",\n", + " color=color_palette[7],\n", + " fill=True,\n", + " edgecolor=\"black\",\n", + ") # , order=list(range(0,200)))\n", "plt.rcParams[\"patch.force_edgecolor\"] = True\n", - "plt.rcParams['axes.edgecolor'] = 'black'\n", + "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n", "axs[1].set_ylabel(\"\")\n", "axs[1].set_xlim([0, 100])\n", "axs[1].axvline(x=50, color=\"red\", linestyle=\"-.\")\n", @@ -246,16 +285,20 @@ "x_mol = molReader.load_MOL(file)\n", "x_mol\n", "\n", - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + ")\n", "\n", "img = Chem.Draw.MolToImage(x_mol, ax=axs[0])\n", "\n", "axs[0].grid(False)\n", - "axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False)\n", - "axs[0].set_title(x[\"Name\"]+ \" structure:\\n\" + Chem.MolToSmiles(x_mol))\n", + "axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + ")\n", + "axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + Chem.MolToSmiles(x_mol))\n", "axs[0].imshow(img)\n", "axs[0].axis(\"off\")\n", - "sv.plot_spectrum(title=x[\"Name\"] + \" MS/MS spectrum\", spectrum=x, ax=axs[1])\n" + "sv.plot_spectrum(title=x[\"Name\"] + \" MS/MS spectrum\", spectrum=x, ax=axs[1])" ] }, { @@ -266,25 +309,37 @@ }, "outputs": [], "source": [ + "# print(df_nist.loc[1474])\n", + "\n", + "print(\n", + " \"Reading structure information in MOL format from library files (this may take a while)\"\n", + ")\n", "\n", - "#print(df_nist.loc[1474])\n", "\n", 
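+    "# Lookup convention (as read from fetch_mol below, a sketch, not verified\n",
+    "# against the library layout): structure files are resolved by CAS number\n",
+    "# first (\"S\" + CASNO + \".MOL\", e.g. a hypothetical S58082.MOL), then by\n",
+    "# library ID (\"ID\" + ID + \".MOL\") as a fallback; entries whose file cannot\n",
+    "# be loaded come back as missing values and are dropped via notna() below.\n",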
- "print(\"Reading structure information in MOL format from library files (this may take a while)\")\n", "def fetch_mol(data):\n", - " file = library_directory + library_name + \".MOL/\" + \"S\" + str(data[\"CASNO\"]) + \".MOL\"\n", + " file = (\n", + " library_directory + library_name + \".MOL/\" + \"S\" + str(data[\"CASNO\"]) + \".MOL\"\n", + " )\n", " if not os.path.exists(file):\n", - " file = library_directory + library_name + \".MOL/\" + \"ID\" + str(data[\"ID\"]) + \".MOL\"\n", + " file = (\n", + " library_directory + library_name + \".MOL/\" + \"ID\" + str(data[\"ID\"]) + \".MOL\"\n", + " )\n", " return molReader.load_MOL(file)\n", "\n", - "df_nist= df_nist[~df_nist[\"InChIKey\"].isnull()] # Drop all without key (Not neccessarily neccesary)\n", + "\n", + "df_nist = df_nist[\n", + " ~df_nist[\"InChIKey\"].isnull()\n", + "] # Drop all without key (Not neccessarily neccesary)\n", "df_nist[\"MOL\"] = df_nist.apply(fetch_mol, axis=1)\n", - "print(f\"Successfully interpreted {sum(df_nist['MOL'].notna())} from {df_nist.shape[0]} entries. Dropping the rest.\")\n", + "print(\n", + " f\"Successfully interpreted {sum(df_nist['MOL'].notna())} from {df_nist.shape[0]} entries. Dropping the rest.\"\n", + ")\n", "\n", - "df_nist = df_nist[df_nist['MOL'].notna()]\n", + "df_nist = df_nist[df_nist[\"MOL\"].notna()]\n", "df_nist[\"SMILES\"] = df_nist[\"MOL\"].apply(Chem.MolToSmiles)\n", "df_nist[\"InChI\"] = df_nist[\"MOL\"].apply(Chem.MolToInchi)\n", "df_nist[\"K\"] = df_nist[\"MOL\"].apply(Chem.MolToInchiKey)\n", - "df_nist[\"ExactMolWeight\"] = df_nist[\"MOL\"].apply(Chem.Descriptors.ExactMolWt)\n" + "df_nist[\"ExactMolWeight\"] = df_nist[\"MOL\"].apply(Chem.Descriptors.ExactMolWt)" ] }, { @@ -311,10 +366,16 @@ "source": [ "correct_keys = df_nist.apply(lambda x: x[\"InChIKey\"] == x[\"K\"], axis=1)\n", "s = \"confirmed!\" if correct_keys.all() else \"not confirmed !! Attention!\"\n", - "print(f\"Confirming whether computed and provided InChI-Keys are correct. Result: {s} ({correct_keys.sum()/len(correct_keys):0.2f} correct)\")\n", - "half_keys = df_nist.apply(lambda x: x[\"InChIKey\"].split('-')[0] == x[\"K\"].split('-')[0], axis=1)\n", + "print(\n", + " f\"Confirming whether computed and provided InChI-Keys are correct. Result: {s} ({correct_keys.sum() / len(correct_keys):0.2f} correct)\"\n", + ")\n", + "half_keys = df_nist.apply(\n", + " lambda x: x[\"InChIKey\"].split(\"-\")[0] == x[\"K\"].split(\"-\")[0], axis=1\n", + ")\n", "s = \"confirmed!\" if half_keys.all() else \"not confirmed !! Attention!\"\n", - "print(f\"Checking if main layer InChI-Keys are correct. Result: {s} ({half_keys.sum()/len(half_keys):0.3f} correct)\")\n", + "print(\n", + " f\"Checking if main layer InChI-Keys are correct. 
Result: {s} ({half_keys.sum() / len(half_keys):0.3f} correct)\"\n", + ")\n", "\n", "print(\"Dropping all other.\")\n", "df_nist[\"matching_key\"] = df_nist.apply(lambda x: x[\"InChIKey\"] == x[\"K\"], axis=1)\n", @@ -362,12 +423,17 @@ "from modules.MOL.constants import ADDUCT_WEIGHTS\n", "\n", "\n", - "\n", "df_nist = df_nist[df_nist[\"Num peaks\"] > MIN_PEAKS]\n", - "df_nist = df_nist[df_nist[\"Num peaks\"] < MAX_PEAKS] #TODO WHY MAX CUTOFF: REMOVE!!\n", - "df_nist[\"theoretical_precursor_mz\"] = df_nist[\"ExactMolWeight\"] + df_nist[\"Precursor_type\"].map(ADDUCT_WEIGHTS)\n", - "df_nist = df_nist[df_nist[\"Precursor_type\"].apply(lambda ptype: ptype in PRECURSOR_TYPES)]\n", - "df_nist[\"precursor_offset\"] = df_nist[\"PrecursorMZ\"] - df_nist[\"theoretical_precursor_mz\"]\n", + "df_nist = df_nist[df_nist[\"Num peaks\"] < MAX_PEAKS] # TODO WHY MAX CUTOFF: REMOVE!!\n", + "df_nist[\"theoretical_precursor_mz\"] = df_nist[\"ExactMolWeight\"] + df_nist[\n", + " \"Precursor_type\"\n", + "].map(ADDUCT_WEIGHTS)\n", + "df_nist = df_nist[\n", + " df_nist[\"Precursor_type\"].apply(lambda ptype: ptype in PRECURSOR_TYPES)\n", + "]\n", + "df_nist[\"precursor_offset\"] = (\n", + " df_nist[\"PrecursorMZ\"] - df_nist[\"theoretical_precursor_mz\"]\n", + ")\n", "\n", "print(f\"Shape {df_nist.shape}\")" ] @@ -378,17 +444,32 @@ "metadata": {}, "outputs": [], "source": [ - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 1.5]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 1.5]}, sharey=False\n", + ")\n", "for ax in axs:\n", - " ax.tick_params('x', labelrotation=45)\n", - "\n", - "sns.scatterplot(ax=axs[0], data=df_nist,x=\"precursor_offset\", y='PrecursorMZ', palette=color_palette)\n", - "sns.histplot(ax=axs[1], data=df_nist, x='precursor_offset', color=color_palette[7], fill=True, edgecolor=\"black\")#, order=list(range(0,200)))\n", + " ax.tick_params(\"x\", labelrotation=45)\n", + "\n", + "sns.scatterplot(\n", + " ax=axs[0],\n", + " data=df_nist,\n", + " x=\"precursor_offset\",\n", + " y=\"PrecursorMZ\",\n", + " palette=color_palette,\n", + ")\n", + "sns.histplot(\n", + " ax=axs[1],\n", + " data=df_nist,\n", + " x=\"precursor_offset\",\n", + " color=color_palette[7],\n", + " fill=True,\n", + " edgecolor=\"black\",\n", + ") # , order=list(range(0,200)))\n", "plt.rcParams[\"patch.force_edgecolor\"] = True\n", - "plt.rcParams['axes.edgecolor'] = 'black'\n", + "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n", "axs[1].set_ylabel(\"\")\n", - "#axs[1].set_xlim([0, 100])\n", - "#axs[1].axvline(x=50, color=\"red\", linestyle=\"-.\")\n", + "# axs[1].set_xlim([0, 100])\n", + "# axs[1].axvline(x=50, color=\"red\", linestyle=\"-.\")\n", "\n", "plt.show()" ] @@ -442,27 +523,35 @@ "\n", "from modules.MOL.collision_energy import align_CE\n", "\n", - "df_nist[\"CE\"] = df_nist.apply(lambda x: align_CE(x[\"Collision_energy\"], x[\"theoretical_precursor_mz\"]), axis=1) #modules.MOL.collision_energy.align_CE) \n", + "df_nist[\"CE\"] = df_nist.apply(\n", + " lambda x: align_CE(x[\"Collision_energy\"], x[\"theoretical_precursor_mz\"]), axis=1\n", + ") # modules.MOL.collision_energy.align_CE)\n", "df_nist[\"CE_type\"] = df_nist[\"CE\"].apply(type)\n", - "df_nist[\"CE_derived_from_NCE\"] = df_nist[\"Collision_energy\"].apply(lambda x: \"%\" in str(x))\n", + "df_nist[\"CE_derived_from_NCE\"] = df_nist[\"Collision_energy\"].apply(\n", + " lambda x: \"%\" in str(x)\n", + ")\n", "# df_test = 
df_nist[df_nist[\"Collision_energy\"].apply(lambda x: \"%\" in str(x))][\"Collision_energy\"]\n", "# df_test = df_test.apply(lambda x: x.split('%')[0].strip().split(' ')[-1])\n", "# for d in df_test:\n", - "# try: \n", + "# try:\n", "# float(d)\n", "# except:\n", "# print(d)\n", "# TODO FIND MORE CE derived from different NCE types\n", "\n", - "print(\"Distinguish CE absolute values (eV - float) and normalized CE (in % - str format)\")\n", + "print(\n", + " \"Distinguish CE absolute values (eV - float) and normalized CE (in % - str format)\"\n", + ")\n", "print(df_nist[\"CE_type\"].value_counts())\n", "\n", "print(\"Removing all but absolute values\")\n", "df_nist = df_nist[df_nist[\"CE_type\"] == float]\n", "df_nist = df_nist[~df_nist[\"CE\"].isnull()]\n", - "#len(df_nist['CE'].unique())\n", + "# len(df_nist['CE'].unique())\n", "\n", - "print(f'Detected {len(df_nist[\"CE\"].unique())} unique collision energies in range from {np.min(df_nist[\"CE\"])} to {max(df_nist[\"CE\"])} eV')\n" + "print(\n", + " f\"Detected {len(df_nist['CE'].unique())} unique collision energies in range from {np.min(df_nist['CE'])} to {max(df_nist['CE'])} eV\"\n", + ")" ] }, { @@ -471,7 +560,7 @@ "metadata": {}, "outputs": [], "source": [ - "df_nist[df_nist[\"Instrument_type\"] ==\"HCD\"][\"Collision_energy\"].value_counts()[:100]" + "df_nist[df_nist[\"Instrument_type\"] == \"HCD\"][\"Collision_energy\"].value_counts()[:100]" ] }, { @@ -481,13 +570,24 @@ "outputs": [], "source": [ "fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4))\n", - "#for ax in axs:\n", + "# for ax in axs:\n", "# ax.tick_params('x', labelrotation=45)\n", "\n", - "#sns.scatterplot(ax=axs[0], data=df_nist,x=\"precursor_offset\", y='PrecursorMZ', palette=color_palette)\n", - "sns.histplot(ax=ax, data=df_nist, x='CE', hue=\"CE_derived_from_NCE\", palette=[color_palette[4], color_palette[2]], multiple=\"stack\", fill=True, binwidth=2, edgecolor=\"black\", binrange=[0, 200])#, order=list(range(0,200)))\n", + "# sns.scatterplot(ax=axs[0], data=df_nist,x=\"precursor_offset\", y='PrecursorMZ', palette=color_palette)\n", + "sns.histplot(\n", + " ax=ax,\n", + " data=df_nist,\n", + " x=\"CE\",\n", + " hue=\"CE_derived_from_NCE\",\n", + " palette=[color_palette[4], color_palette[2]],\n", + " multiple=\"stack\",\n", + " fill=True,\n", + " binwidth=2,\n", + " edgecolor=\"black\",\n", + " binrange=[0, 200],\n", + ") # , order=list(range(0,200)))\n", "plt.rcParams[\"patch.force_edgecolor\"] = True\n", - "plt.rcParams['axes.edgecolor'] = 'black'\n", + "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n", "plt.show()\n", "print(f\"{df_nist.shape[0]} spectra remaining with aligned absolute collision energies\")" ] @@ -512,6 +612,7 @@ "%%capture\n", "from modules.MOL.Metabolite import Metabolite\n", "from modules.MOL.constants import PPM\n", + "\n", "TOLERANCE = 200 * PPM\n", "\n", "\n", @@ -519,7 +620,12 @@ "df_nist[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", "df_nist[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes())\n", "df_nist[\"Metabolite\"].apply(lambda x: x.fragment_MOL())\n", - "df_nist.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=TOLERANCE), axis=1)" + "df_nist.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=TOLERANCE\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -529,15 +635,20 @@ "outputs": [], "source": [ "from modules.MOL.mol_graph import draw_graph\n", 
+ "\n", "x = df_nist.loc[EXAMPLE_ID]\n", "\n", - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 1]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 1]}, sharey=False\n", + ")\n", "\n", "img = Chem.Draw.MolToImage(x_mol, ax=axs[0])\n", "\n", "axs[0].grid(False)\n", - "axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False)\n", - "axs[0].set_title(x[\"Name\"]+ \" structure:\\n\" + Chem.MolToSmiles(x_mol))\n", + "axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + ")\n", + "axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + Chem.MolToSmiles(x_mol))\n", "axs[0].imshow(img)\n", "axs[0].axis(\"off\")\n", "\n", @@ -578,26 +689,29 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "x = df_nist.loc[EXAMPLE_ID]\n", "\n", "FT = x[\"Metabolite\"].fragmentation_tree\n", - "#frag.build_fragmentation_tree_by_rotatable_bond_breaks()\n", + "# frag.build_fragmentation_tree_by_rotatable_bond_breaks()\n", "print(FT)\n", "\n", - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + ")\n", "\n", "img = Chem.Draw.MolToImage(x[\"MOL\"], ax=axs[0])\n", "\n", "axs[0].grid(False)\n", - "axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False)\n", - "axs[0].set_title(x[\"Name\"]+ \" structure:\\n\" + x[\"SMILES\"])\n", + "axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + ")\n", + "axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + x[\"SMILES\"])\n", "axs[0].imshow(img)\n", "axs[0].axis(\"off\")\n", "sv.plot_spectrum(title=x[\"Name\"] + \" MS/MS spectrum\", spectrum=x, ax=axs[1])\n", "\n", "print(\"Matching peaks to fragments\")\n", - "print(x[\"Metabolite\"].peak_matches)\n" + "print(x[\"Metabolite\"].peak_matches)" ] }, { @@ -615,13 +729,17 @@ "outputs": [], "source": [ "from modules.MOL.constants import DEFAULT_MODES\n", - "df_nist[\"peak_matches\"] = df_nist[\"Metabolite\"].apply(lambda x: getattr(x, \"peak_matches\"))\n", + "\n", + "df_nist[\"peak_matches\"] = df_nist[\"Metabolite\"].apply(\n", + " lambda x: getattr(x, \"peak_matches\")\n", + ")\n", "df_nist[\"num_peaks_matched\"] = df_nist[\"peak_matches\"].apply(len)\n", "\n", + "\n", "def get_match_stats(matches, mode_count={m: 0 for m in DEFAULT_MODES}):\n", " num_unique, num_conflicts = 0, 0\n", " for mz, match_data in matches.items():\n", - " #candidates = match_data[\"fragments\"]\n", + " # candidates = match_data[\"fragments\"]\n", " ion_modes = match_data[\"ion_modes\"]\n", " if len(ion_modes) == 1:\n", " num_unique += 1\n", @@ -632,18 +750,22 @@ " return num_unique, num_conflicts, mode_count\n", "\n", "\n", - "\n", "df_nist[\"match_stats\"] = df_nist[\"peak_matches\"].apply(lambda x: get_match_stats(x))\n", - "df_nist[\"num_unique_peaks_matched\"] = df_nist.apply(lambda x: x[\"match_stats\"][0], axis=1)\n", - "df_nist[\"num_conflicts_in_peak_matching\"] = df_nist.apply(lambda x: x[\"match_stats\"][1], axis=1)\n", + "df_nist[\"num_unique_peaks_matched\"] = df_nist.apply(\n", + " lambda x: x[\"match_stats\"][0], axis=1\n", + ")\n", + "df_nist[\"num_conflicts_in_peak_matching\"] = df_nist.apply(\n", + " lambda x: x[\"match_stats\"][1], axis=1\n", + ")\n", "df_nist[\"match_mode_counts\"] = 
df_nist.apply(lambda x: x[\"match_stats\"][2], axis=1)\n",
-    "u= df_nist[\"num_unique_peaks_matched\"].sum() \n",
-    "s= df_nist[\"num_conflicts_in_peak_matching\"].sum() \n",
-    "print(f\"Total number of uniquely matched peaks: {u} , conflicts found within {s} matches ({100 * s / (u+s):.02f} %))\")\n",
+    "u = df_nist[\"num_unique_peaks_matched\"].sum()\n",
+    "s = df_nist[\"num_conflicts_in_peak_matching\"].sum()\n",
+    "print(\n",
+    "    f\"Total number of uniquely matched peaks: {u}, conflicts found within {s} matches ({100 * s / (u + s):.02f} %)\"\n",
+    ")\n",
     "print(f\"Total number of conflicting peak to fragment matches: {s}\")\n",
     "\n",
-    "df_nist.shape\n",
-    "    "
+    "df_nist.shape"
    ]
   },
   {
@@ -655,19 +777,30 @@
     "fig, axs = plt.subplots(1, 3, figsize=(12.8, 6.4), sharey=False)\n",
     "\n",
     "fig.suptitle(f\"Identified peaks with fragment offset\")\n",
-    "#plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n",
-    "sns.histplot(ax=axs[0],data=df_nist, x=\"num_peaks_matched\", color=color_palette[0], edgecolor=\"black\", bins=range(0,20, 1))\n",
-    "#axs[0].set_ylim(-0.5, 10)\n",
+    "# plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n",
+    "sns.histplot(\n",
+    "    ax=axs[0],\n",
+    "    data=df_nist,\n",
+    "    x=\"num_peaks_matched\",\n",
+    "    color=color_palette[0],\n",
+    "    edgecolor=\"black\",\n",
+    "    bins=range(0, 20, 1),\n",
+    ")\n",
+    "# axs[0].set_ylim(-0.5, 10)\n",
     "axs[0].set_ylabel(\"peaks identified\")\n",
     "\n",
     "\n",
-    "sns.boxplot(ax=axs[1],data=df_nist, y=\"num_unique_peaks_matched\", color=color_palette[1])\n",
+    "sns.boxplot(\n",
+    "    ax=axs[1], data=df_nist, y=\"num_unique_peaks_matched\", color=color_palette[1]\n",
+    ")\n",
     "axs[1].set_ylim(-0.5, 15)\n",
     "axs[1].set_xlabel(\"unique matches\")\n",
     "axs[1].set_ylabel(\"\")\n",
     "\n",
     "\n",
-    "sns.boxplot(ax=axs[2],data=df_nist, y=\"num_conflicts_in_peak_matching\", color=color_palette[3])\n",
+    "sns.boxplot(\n",
+    "    ax=axs[2], data=df_nist, y=\"num_conflicts_in_peak_matching\", color=color_palette[3]\n",
+    ")\n",
     "axs[2].set_ylim(-0.5, 15)\n",
     "axs[2].set_xlabel(\"conflicts\")\n",
     "axs[2].set_ylabel(\"\")\n",
@@ -685,14 +818,28 @@
     "\n",
     "mode_counts = {m: 0 for m in DEFAULT_MODES}\n",
     "\n",
+    "\n",
     "def update_mode_counts(m):\n",
     "    for mode in m.keys():\n",
     "        mode_counts[mode] += m[mode]\n",
     "\n",
+    "\n",
     "df_nist[\"match_mode_counts\"].apply(update_mode_counts)\n",
     "\n",
-    "sns.barplot(ax=axs[0], x=list(mode_counts.keys()), y=[mode_counts[k] for k in mode_counts.keys()], palette=color_palette, edgecolor=\"black\", linewidth=1.5)\n",
-    "axs[1].pie([mode_counts[k] for k in mode_counts.keys()], labels=list(mode_counts.keys()), colors=color_palette, wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5})\n",
+    "sns.barplot(\n",
+    "    ax=axs[0],\n",
+    "    x=list(mode_counts.keys()),\n",
+    "    y=[mode_counts[k] for k in mode_counts.keys()],\n",
+    "    palette=color_palette,\n",
+    "    edgecolor=\"black\",\n",
+    "    linewidth=1.5,\n",
+    ")\n",
+    "axs[1].pie(\n",
+    "    [mode_counts[k] for k in mode_counts.keys()],\n",
+    "    labels=list(mode_counts.keys()),\n",
+    "    colors=color_palette,\n",
+    "    wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5},\n",
+    ")\n",
     "\n",
     "plt.show()"
    ]
   },
   {
@@ -703,9 +850,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "\n",
-    "for i in range(0,6):\n",
-    "    print(f\"Minimum {i} unique peaks identified (including precursors): \", (df_nist[\"num_unique_peaks_matched\"] >= i).sum())\n"
+    "for i in range(0, 6):\n",
+    "    print(\n",
+    "        f\"Minimum {i} unique peaks identified (including precursors): \",\n",
+    "        
(df_nist[\"num_unique_peaks_matched\"] >= i).sum(),\n", + " )" ] }, { @@ -739,12 +888,17 @@ "from modules.MOL.constants import ADDUCT_WEIGHTS\n", "\n", "\n", - "\n", "df_minus = df_minus[df_minus[\"Num peaks\"] > MIN_PEAKS]\n", - "df_minus = df_minus[df_minus[\"Num peaks\"] < MAX_PEAKS] #TODO WHY MAX CUTOFF: REMOVE!!\n", - "df_minus[\"theoretical_precursor_mz\"] = df_minus[\"ExactMolWeight\"] + df_minus[\"Precursor_type\"].map(ADDUCT_WEIGHTS)\n", - "df_minus = df_minus[df_minus[\"Precursor_type\"].apply(lambda ptype: ptype in PRECURSOR_TYPES)]\n", - "df_minus[\"precursor_offset\"] = df_minus[\"PrecursorMZ\"] - df_minus[\"theoretical_precursor_mz\"]\n", + "df_minus = df_minus[df_minus[\"Num peaks\"] < MAX_PEAKS] # TODO WHY MAX CUTOFF: REMOVE!!\n", + "df_minus[\"theoretical_precursor_mz\"] = df_minus[\"ExactMolWeight\"] + df_minus[\n", + " \"Precursor_type\"\n", + "].map(ADDUCT_WEIGHTS)\n", + "df_minus = df_minus[\n", + " df_minus[\"Precursor_type\"].apply(lambda ptype: ptype in PRECURSOR_TYPES)\n", + "]\n", + "df_minus[\"precursor_offset\"] = (\n", + " df_minus[\"PrecursorMZ\"] - df_minus[\"theoretical_precursor_mz\"]\n", + ")\n", "\n", "print(f\"Shape {df_minus.shape}\")" ] @@ -766,26 +920,33 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "# TODO Use more Collision energy types. eg. ramps, stepped, resonant...\n", "\n", - "from modules.MOL.collision_energy import NCE_to_eV#align_CE,\n", + "from modules.MOL.collision_energy import NCE_to_eV # align_CE,\n", "\n", "\n", - "df_minus[\"CE\"] = df_minus.apply(lambda x: align_CE(x[\"Collision_energy\"], x[\"theoretical_precursor_mz\"]), axis=1) #modules.MOL.collision_energy.align_CE) \n", + "df_minus[\"CE\"] = df_minus.apply(\n", + " lambda x: align_CE(x[\"Collision_energy\"], x[\"theoretical_precursor_mz\"]), axis=1\n", + ") # modules.MOL.collision_energy.align_CE)\n", "df_minus[\"CE_type\"] = df_minus[\"CE\"].apply(type)\n", - "df_minus[\"CE_derived_from_NCE\"] = df_minus[\"Collision_energy\"].apply(lambda x: \"%\" in str(x))\n", + "df_minus[\"CE_derived_from_NCE\"] = df_minus[\"Collision_energy\"].apply(\n", + " lambda x: \"%\" in str(x)\n", + ")\n", "\n", "\n", - "print(\"Distinguish CE absolute values (eV - float) and normalized CE (in % - str format)\")\n", + "print(\n", + " \"Distinguish CE absolute values (eV - float) and normalized CE (in % - str format)\"\n", + ")\n", "print(df_minus[\"CE_type\"].value_counts())\n", "\n", "print(\"Removing all but absolute values\")\n", "df_minus = df_minus[df_minus[\"CE_type\"] == float]\n", "df_minus = df_minus[~df_minus[\"CE\"].isnull()]\n", - "#len(df_nist['CE'].unique())\n", + "# len(df_nist['CE'].unique())\n", "\n", - "print(f'Detected {len(df_minus[\"CE\"].unique())} unique collision energies in range from {np.min(df_minus[\"CE\"])} to {max(df_minus[\"CE\"])} eV')\n" + "print(\n", + " f\"Detected {len(df_minus['CE'].unique())} unique collision energies in range from {np.min(df_minus['CE'])} to {max(df_minus['CE'])} eV\"\n", + ")" ] }, { @@ -795,13 +956,24 @@ "outputs": [], "source": [ "fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4))\n", - "#for ax in axs:\n", + "# for ax in axs:\n", "# ax.tick_params('x', labelrotation=45)\n", "\n", - "#sns.scatterplot(ax=axs[0], data=df_nist,x=\"precursor_offset\", y='PrecursorMZ', palette=color_palette)\n", - "sns.histplot(ax=ax, data=df_minus, x='CE', hue=\"CE_derived_from_NCE\", palette=[color_palette[4], color_palette[2]], multiple=\"stack\", fill=True, binwidth=2, edgecolor=\"black\", binrange=[0, 200])#, order=list(range(0,200)))\n", + "# 
sns.scatterplot(ax=axs[0], data=df_nist,x=\"precursor_offset\", y='PrecursorMZ', palette=color_palette)\n",
+    "sns.histplot(\n",
+    "    ax=ax,\n",
+    "    data=df_minus,\n",
+    "    x=\"CE\",\n",
+    "    hue=\"CE_derived_from_NCE\",\n",
+    "    palette=[color_palette[4], color_palette[2]],\n",
+    "    multiple=\"stack\",\n",
+    "    fill=True,\n",
+    "    binwidth=2,\n",
+    "    edgecolor=\"black\",\n",
+    "    binrange=[0, 200],\n",
+    ")  # , order=list(range(0,200)))\n",
     "plt.rcParams[\"patch.force_edgecolor\"] = True\n",
-    "plt.rcParams['axes.edgecolor'] = 'black'\n",
+    "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n",
     "plt.show()\n",
     "print(f\"{df_minus.shape[0]} spectra remaining with aligned absolute collision energies\")"
    ]
   },
   {
@@ -828,9 +1000,14 @@
     "df_minus[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n",
     "df_minus[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes())\n",
     "df_minus[\"Metabolite\"].apply(lambda x: x.fragment_MOL())\n",
-    "#df_minus[\"Metabolite\"].apply(lambda x: x.fragmentation_tree.set_fragment_modes(constants.NEGATIVE_MODES))\n",
-    "\n",
-    "df_minus.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=TOLERANCE), axis=1)"
+    "# df_minus[\"Metabolite\"].apply(lambda x: x.fragmentation_tree.set_fragment_modes(constants.NEGATIVE_MODES))\n",
+    "\n",
+    "df_minus.apply(\n",
+    "    lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n",
+    "        x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=TOLERANCE\n",
+    "    ),\n",
+    "    axis=1,\n",
+    ")"
    ]
   },
   {
@@ -875,24 +1052,28 @@
    }
   ],
   "source": [
-    "\n",
-    "\n",
-    "df_minus[\"peak_matches\"] = df_minus[\"Metabolite\"].apply(lambda x: getattr(x, \"peak_matches\"))\n",
+    "df_minus[\"peak_matches\"] = df_minus[\"Metabolite\"].apply(\n",
+    "    lambda x: getattr(x, \"peak_matches\")\n",
+    ")\n",
     "df_minus[\"num_peaks_matched\"] = df_minus[\"peak_matches\"].apply(len)\n",
     "\n",
     "\n",
-    "\n",
     "df_minus[\"match_stats\"] = df_minus[\"peak_matches\"].apply(lambda x: get_match_stats(x))\n",
-    "df_minus[\"num_unique_peaks_matched\"] = df_minus.apply(lambda x: x[\"match_stats\"][0], axis=1)\n",
-    "df_minus[\"num_conflicts_in_peak_matching\"] = df_minus.apply(lambda x: x[\"match_stats\"][1], axis=1)\n",
+    "df_minus[\"num_unique_peaks_matched\"] = df_minus.apply(\n",
+    "    lambda x: x[\"match_stats\"][0], axis=1\n",
+    ")\n",
+    "df_minus[\"num_conflicts_in_peak_matching\"] = df_minus.apply(\n",
+    "    lambda x: x[\"match_stats\"][1], axis=1\n",
+    ")\n",
     "df_minus[\"match_mode_counts\"] = df_minus.apply(lambda x: x[\"match_stats\"][2], axis=1)\n",
-    "u= df_minus[\"num_unique_peaks_matched\"].sum() \n",
-    "s= df_minus[\"num_conflicts_in_peak_matching\"].sum() \n",
-    "print(f\"Total number of uniquely matched peaks: {u} , conflicts found within {s} matches ({100 * s / (u+s):.02f} %))\")\n",
+    "u = df_minus[\"num_unique_peaks_matched\"].sum()\n",
+    "s = df_minus[\"num_conflicts_in_peak_matching\"].sum()\n",
+    "print(\n",
+    "    f\"Total number of uniquely matched peaks: {u}, conflicts found within {s} matches ({100 * s / (u + s):.02f} %)\"\n",
+    ")\n",
     "print(f\"Total number of conflicting peak to fragment matches: {s}\")\n",
     "\n",
-    "df_minus.shape\n",
-    "    "
+    "df_minus.shape"
    ]
   },
   {
@@ -904,19 +1085,30 @@
     "fig, axs = plt.subplots(1, 3, figsize=(12.8, 6.4), sharey=False)\n",
     "\n",
     "fig.suptitle(f\"Identified peaks with fragment offset\")\n",
-    "#plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n",
-    "sns.histplot(ax=axs[0],data=df_minus, x=\"num_peaks_matched\", 
color=color_palette[0], edgecolor=\"black\", bins=range(0,20, 1))\n", - "#axs[0].set_ylim(-0.5, 10)\n", + "# plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n", + "sns.histplot(\n", + " ax=axs[0],\n", + " data=df_minus,\n", + " x=\"num_peaks_matched\",\n", + " color=color_palette[0],\n", + " edgecolor=\"black\",\n", + " bins=range(0, 20, 1),\n", + ")\n", + "# axs[0].set_ylim(-0.5, 10)\n", "axs[0].set_ylabel(\"peaks identified\")\n", "\n", "\n", - "sns.boxplot(ax=axs[1],data=df_minus, y=\"num_unique_peaks_matched\", color=color_palette[1])\n", + "sns.boxplot(\n", + " ax=axs[1], data=df_minus, y=\"num_unique_peaks_matched\", color=color_palette[1]\n", + ")\n", "axs[1].set_ylim(-0.5, 15)\n", "axs[1].set_xlabel(\"unique matches\")\n", "axs[1].set_ylabel(\"\")\n", "\n", "\n", - "sns.boxplot(ax=axs[2],data=df_minus, y=\"num_conflicts_in_peak_matching\", color=color_palette[3])\n", + "sns.boxplot(\n", + " ax=axs[2], data=df_minus, y=\"num_conflicts_in_peak_matching\", color=color_palette[3]\n", + ")\n", "axs[2].set_ylim(-0.5, 15)\n", "axs[2].set_xlabel(\"conflicts\")\n", "axs[2].set_ylabel(\"\")\n", @@ -934,16 +1126,32 @@ "\n", "mode_counts = {m: 0 for m in DEFAULT_MODES}\n", "\n", + "\n", "def update_mode_counts(m):\n", " for mode in m.keys():\n", " mode_counts[mode] += m[mode]\n", "\n", - "df_minus[\"match_mode_counts\"].apply(update_mode_counts)\n", "\n", - "mode_counts = dict((key.replace(\"]+\", \"]-\"), value) for (key, value) in mode_counts.items())\n", + "df_minus[\"match_mode_counts\"].apply(update_mode_counts)\n", "\n", - "sns.barplot(ax=axs[0], x=list(mode_counts.keys()), y=[mode_counts[k] for k in mode_counts.keys()], palette=color_palette, edgecolor=\"black\", linewidth=1.5)\n", - "axs[1].pie([mode_counts[k] for k in mode_counts.keys()], labels=list(mode_counts.keys()), colors=color_palette, wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5})\n", + "mode_counts = dict(\n", + " (key.replace(\"]+\", \"]-\"), value) for (key, value) in mode_counts.items()\n", + ")\n", + "\n", + "sns.barplot(\n", + " ax=axs[0],\n", + " x=list(mode_counts.keys()),\n", + " y=[mode_counts[k] for k in mode_counts.keys()],\n", + " palette=color_palette,\n", + " edgecolor=\"black\",\n", + " linewidth=1.5,\n", + ")\n", + "axs[1].pie(\n", + " [mode_counts[k] for k in mode_counts.keys()],\n", + " labels=list(mode_counts.keys()),\n", + " colors=color_palette,\n", + " wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5},\n", + ")\n", "\n", "plt.show()" ] @@ -954,9 +1162,11 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "for i in range(0,6):\n", - " print(f\"Minimum {i} unique peaks identified (including precursors): \", (df_minus[\"num_unique_peaks_matched\"] >= i).sum())\n" + "for i in range(0, 6):\n", + " print(\n", + " f\"Minimum {i} unique peaks identified (including precursors): \",\n", + " (df_minus[\"num_unique_peaks_matched\"] >= i).sum(),\n", + " )" ] }, { @@ -984,9 +1194,11 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "for i in range(0,6):\n", - " print(f\"Minimum {i} unique peaks identified (including precursors): \", (df[\"num_unique_peaks_matched\"] >= i).sum())\n" + "for i in range(0, 6):\n", + " print(\n", + " f\"Minimum {i} unique peaks identified (including precursors): \",\n", + " (df[\"num_unique_peaks_matched\"] >= i).sum(),\n", + " )" ] }, { @@ -997,25 +1209,62 @@ "source": [ "save_df = False\n", "name = \"nist_msms_filtered\"\n", - "date = \"XXX\" #\"07_2023\"\n", + "date = \"XXX\" # \"07_2023\"\n", "min_peaks = 1\n", "\n", "if save_df:\n", - " 
key_columns = ['Name', 'Synon', 'Notes', 'Precursor_type', 'Spectrum_type',\n", - " 'PrecursorMZ', 'Instrument_type', 'Instrument', 'Sample_inlet',\n", - " 'Ionization', 'Collision_energy', 'Ion_mode', 'Special_fragmentation',\n", - " 'InChIKey', 'Formula', 'MW', 'ExactMass', 'CASNO', 'NISTNO', 'ID',\n", - " 'Comment', 'Num peaks', 'peaks', 'Link', 'Related_CAS#',\n", - " 'Collision_gas', 'Pressure', 'In-source_voltage', 'msN_pathway', 'MOL',\n", - " 'SMILES', 'InChI', 'K', 'ExactMolWeight', 'matching_key',\n", - " 'theoretical_precursor_mz', 'precursor_offset', 'CE', 'CE_type', 'peak_matches',\n", - " 'num_peaks_matched', 'match_stats', 'num_unique_peaks_matched',\n", - " 'num_conflicts_in_peak_matching', 'match_mode_counts']\n", - " file = library_directory + name + \"_min\" + str(min_peaks) + \"_\" + date + \".csv\"\n", - " print(\"saving to \", file)\n", - " df[df[\"num_unique_peaks_matched\"] >= min_peaks][key_columns].to_csv(file)\n", - " \n", - " #df[key_columns].to_csv(library_directory + name + \"all\" + \"_\" + date + \".csv\")" + " key_columns = [\n", + " \"Name\",\n", + " \"Synon\",\n", + " \"Notes\",\n", + " \"Precursor_type\",\n", + " \"Spectrum_type\",\n", + " \"PrecursorMZ\",\n", + " \"Instrument_type\",\n", + " \"Instrument\",\n", + " \"Sample_inlet\",\n", + " \"Ionization\",\n", + " \"Collision_energy\",\n", + " \"Ion_mode\",\n", + " \"Special_fragmentation\",\n", + " \"InChIKey\",\n", + " \"Formula\",\n", + " \"MW\",\n", + " \"ExactMass\",\n", + " \"CASNO\",\n", + " \"NISTNO\",\n", + " \"ID\",\n", + " \"Comment\",\n", + " \"Num peaks\",\n", + " \"peaks\",\n", + " \"Link\",\n", + " \"Related_CAS#\",\n", + " \"Collision_gas\",\n", + " \"Pressure\",\n", + " \"In-source_voltage\",\n", + " \"msN_pathway\",\n", + " \"MOL\",\n", + " \"SMILES\",\n", + " \"InChI\",\n", + " \"K\",\n", + " \"ExactMolWeight\",\n", + " \"matching_key\",\n", + " \"theoretical_precursor_mz\",\n", + " \"precursor_offset\",\n", + " \"CE\",\n", + " \"CE_type\",\n", + " \"peak_matches\",\n", + " \"num_peaks_matched\",\n", + " \"match_stats\",\n", + " \"num_unique_peaks_matched\",\n", + " \"num_conflicts_in_peak_matching\",\n", + " \"match_mode_counts\",\n", + " ]\n", + " file = library_directory + name + \"_min\" + str(min_peaks) + \"_\" + date + \".csv\"\n", + " print(\"saving to \", file)\n", + " df[df[\"num_unique_peaks_matched\"] >= min_peaks][key_columns].to_csv(file)\n", + "\n", + " # df[key_columns].to_csv(library_directory + name + \"all\" + \"_\" + date + \".csv\")" ] } ], diff --git a/notebooks/break_tendency.ipynb b/notebooks/break_tendency.ipynb index 44003dd..2bf97c2 100644 --- a/notebooks/break_tendency.ipynb +++ b/notebooks/break_tendency.ipynb @@ -27,7 +27,7 @@ "import torch\n", "\n", "seed = 42\n", - "#torch.set_default_dtype(torch.float64)\n", + "# torch.set_default_dtype(torch.float64)\n", "torch.manual_seed(seed)\n", "torch.set_printoptions(precision=2, sci_mode=False)\n", "\n", @@ -41,22 +41,23 @@ "import seaborn as sns\n", "\n", "\n", - "\n", "# Load Modules\n", "sys.path.append(\"..\")\n", "from os.path import expanduser\n", + "\n", "home = expanduser(\"~\")\n", "from fiora.MOL.constants import DEFAULT_PPM, PPM, DEFAULT_MODES\n", "from fiora.IO.LibraryLoader import LibraryLoader\n", - "from fiora.MOL.FragmentationTree import FragmentationTree \n", + "from fiora.MOL.FragmentationTree import FragmentationTree\n", "import fiora.visualization.spectrum_visualizer as sv\n", "\n", "from sklearn.metrics import r2_score\n", "import scipy\n", "from rdkit import RDLogger\n", - 
"RDLogger.DisableLog('rdApp.*')\n", "\n", - "print(f'Working with Python {sys.version}')\n" + "RDLogger.DisableLog(\"rdApp.*\")\n", + "\n", + "print(f\"Working with Python {sys.version}\")" ] }, { @@ -82,12 +83,15 @@ ], "source": [ "from typing import Literal\n", + "\n", "lib: Literal[\"NIST\", \"MSDIAL\", \"NIST/MSDIAL\"] = \"NIST/MSDIAL\"\n", "print(f\"Preparing {lib} library\")\n", "\n", - "test_run = False # Default: False\n", + "test_run = False # Default: False\n", "if test_run:\n", - " print(\"+++ This is a test run with a small subset of data points. Results are not representative. +++\")" + " print(\n", + " \"+++ This is a test run with a small subset of data points. Results are not representative. +++\"\n", + " )" ] }, { @@ -98,44 +102,50 @@ "source": [ "# key map to read metadata from pandas DataFrame\n", "metadata_key_map = {\n", - " \"name\": \"Name\",\n", - " \"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"ionization\": \"Ionization\",\n", - " \"precursor_mz\": \"PrecursorMZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", - " \"ccs\": \"CCS\"\n", - " }\n", + " \"name\": \"Name\",\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"ionization\": \"Ionization\",\n", + " \"precursor_mz\": \"PrecursorMZ\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"RETENTIONTIME\",\n", + " \"ccs\": \"CCS\",\n", + "}\n", "\n", "\n", "#\n", "# Load specified libraries and align metadata\n", "#\n", "\n", + "\n", "def load_nist():\n", " library_name = \"nist_msms_filteredall_07_2023\"\n", - " library_directory = f\"{home}/data/metabolites/NIST17/msp/nist_msms/\" \n", + " library_directory = f\"{home}/data/metabolites/NIST17/msp/nist_msms/\"\n", " L = LibraryLoader()\n", - " df = L.load_from_csv(library_directory + library_name + \".csv\") \n", + " df = L.load_from_csv(library_directory + library_name + \".csv\")\n", " df[\"RETENTIONTIME\"] = np.nan\n", " df[\"CCS\"] = np.nan\n", " df[\"PPM_num\"] = 50\n", " df[\"ppm_peak_tolerance\"] = df[\"PPM_num\"] * PPM\n", " df[\"lib\"] = \"NIST\"\n", " df[\"origin\"] = \"NIST\"\n", - " \n", + "\n", " return df\n", "\n", + "\n", "def load_msdial():\n", " library_name = \"ms_dial_filtered_all_mid_08_2023\"\n", " library_directory = f\"{home}/data/metabolites/MS_DIAL/\"\n", " L = LibraryLoader()\n", " df = L.load_from_csv(library_directory + library_name + \".csv\")\n", - " \n", + "\n", " orbitrap_nametags = [\"Orbitrap\"]\n", " qtof_nametags = [\"QTOF\", \"LC-ESI-QTOF\", \"ESI-QTOF\"]\n", - " df[\"Instrument_type\"] = df[\"INSTRUMENTTYPE\"].apply(lambda x: \"HCD\" if x in orbitrap_nametags else \"Q-TOF\" if x in qtof_nametags else x)\n", + " df[\"Instrument_type\"] = df[\"INSTRUMENTTYPE\"].apply(\n", + " lambda x: (\n", + " \"HCD\" if x in orbitrap_nametags else \"Q-TOF\" if x in qtof_nametags else x\n", + " )\n", + " )\n", " df[\"Ionization\"] = \"ESI\"\n", " df[\"original_RT\"] = df[\"RETENTIONTIME\"].astype(float)\n", " df[\"RETENTIONTIME\"] = df[\"RETENTIONTIME\"].astype(float)\n", @@ -146,30 +156,41 @@ " df[\"PPM_num\"] = 50\n", " df[\"ppm_peak_tolerance\"] = df[\"PPM_num\"] * PPM\n", " df[\"lib\"] = \"MSDIAL\"\n", - " \n", + "\n", " # Filter out retention times using other phase types (e.g. 
HILIC) or of unknown/heterogeneous souces\n", - " #df[df[\"origin\"] == \"MassBank High Quality Mass Spectral Database\"][\"RETENTIONTIME\"] = np.nan\n", - " #df[df[\"origin\"] == \"Fiehn Lab HILIC Library\"][\"RETENTIONTIME\"] = np.nan\n", - " bad_RT_libs = [\"MassBank High Quality Mass Spectral Database\", \"Fiehn Lab HILIC Library\"]\n", - " potential_homogenous_RT_libs = ['BMDMS-NP']# , 'RIKEN Plant Specialized Metabolome Annotation (PlaSMA) Authentic Standard Library' 'BMDMS-NP' , \"Global Natural Product Social Molecular Networking Library\"]\n", - " df[\"RETENTIONTIME\"] = df.apply(lambda x: x[\"RETENTIONTIME\"] if x[\"origin\"] in potential_homogenous_RT_libs else np.nan, axis=1)\n", - " \n", + " # df[df[\"origin\"] == \"MassBank High Quality Mass Spectral Database\"][\"RETENTIONTIME\"] = np.nan\n", + " # df[df[\"origin\"] == \"Fiehn Lab HILIC Library\"][\"RETENTIONTIME\"] = np.nan\n", + " bad_RT_libs = [\n", + " \"MassBank High Quality Mass Spectral Database\",\n", + " \"Fiehn Lab HILIC Library\",\n", + " ]\n", + " potential_homogenous_RT_libs = [\n", + " \"BMDMS-NP\"\n", + " ] # , 'RIKEN Plant Specialized Metabolome Annotation (PlaSMA) Authentic Standard Library' 'BMDMS-NP' , \"Global Natural Product Social Molecular Networking Library\"]\n", + " df[\"RETENTIONTIME\"] = df.apply(\n", + " lambda x: (\n", + " x[\"RETENTIONTIME\"]\n", + " if x[\"origin\"] in potential_homogenous_RT_libs\n", + " else np.nan\n", + " ),\n", + " axis=1,\n", + " )\n", + "\n", " return df\n", "\n", + "\n", "if lib == \"NIST\":\n", - " df = load_nist() \n", + " df = load_nist()\n", "elif lib == \"MSDIAL\":\n", " df = load_msdial()\n", "elif lib == \"NIST/MSDIAL\":\n", " df = pd.concat([load_nist(), load_msdial()], ignore_index=True)\n", - " #df.reset_index(inplace=True) # Avoid conflict from index overlap of the two dataframes\n", - " \n", + " # df.reset_index(inplace=True) # Avoid conflict from index overlap of the two dataframes\n", + "\n", "# Restore dictionary values\n", "dict_columns = [\"peaks\"]\n", "for col in dict_columns:\n", - " df[col] = df[col].apply(ast.literal_eval)\n", - " \n", - "\n" + " df[col] = df[col].apply(ast.literal_eval)" ] }, { @@ -199,11 +220,12 @@ } ], "source": [ - "#%%capture\n", + "# %%capture\n", "from fiora.MOL.Metabolite import Metabolite\n", "from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", "from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder\n", "from fiora.GNN.SetupFeatureEncoder import SetupFeatureEncoder\n", + "\n", "#\n", "filter_spectra = True\n", "CE_upper_limit = 100.0\n", @@ -211,9 +233,8 @@ "\n", "\n", "if test_run:\n", - " df = df.iloc[5000:6000,:]\n", - " #df = df.iloc[5000:20000,:]\n", - "\n", + " df = df.iloc[5000:6000, :]\n", + " # df = df.iloc[5000:20000,:]\n", "\n", "\n", "df[\"Metabolite\"] = df[\"SMILES\"].apply(Metabolite)\n", @@ -221,26 +242,47 @@ "\n", "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", - "setup_encoder = SetupFeatureEncoder(feature_list=[\"collision_energy\", \"molecular_weight\", \"precursor_mode\", \"instrument\"])\n", - "rt_encoder = SetupFeatureEncoder(feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"])\n", + "setup_encoder = SetupFeatureEncoder(\n", + " feature_list=[\n", + " \"collision_energy\",\n", + " \"molecular_weight\",\n", + " \"precursor_mode\",\n", + " \"instrument\",\n", + " ]\n", + ")\n", + "rt_encoder = SetupFeatureEncoder(\n", + " 
feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"]\n", + ")\n", "\n", "if filter_spectra:\n", - " setup_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit \n", - " setup_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n", - " rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n", + " setup_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", + " setup_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", + " rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", "df[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", "\n", - "df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", - "df.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + "df[\"summary\"] = df.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n", + ")\n", + "df.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + ")\n", "\n", "if filter_spectra:\n", " num_ori = df.shape[0]\n", - " correct_energy = df[\"Metabolite\"].apply(lambda x: x.metadata[\"collision_energy\"] <= CE_upper_limit and x.metadata[\"collision_energy\"] > 1) \n", + " correct_energy = df[\"Metabolite\"].apply(\n", + " lambda x: (\n", + " x.metadata[\"collision_energy\"] <= CE_upper_limit\n", + " and x.metadata[\"collision_energy\"] > 1\n", + " )\n", + " )\n", " df = df[correct_energy]\n", - " correct_weight = df[\"Metabolite\"].apply(lambda x: x.metadata[\"molecular_weight\"] <= weight_upper_limit)\n", - " df = df[correct_weight] \n", + " correct_weight = df[\"Metabolite\"].apply(\n", + " lambda x: x.metadata[\"molecular_weight\"] <= weight_upper_limit\n", + " )\n", + " df = df[correct_weight]\n", " print(f\"Filtering spectra ({num_ori}) down to {df.shape[0]}\")\n", - " #print(df[\"Precursor_type\"].value_counts())\n" + " # print(df[\"Precursor_type\"].value_counts())" ] }, { @@ -249,8 +291,8 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "import pandas as pd\n", + "\n", "casmi16_path = f\"{home}/data/metabolites/CASMI_2016/casmi16_challenges_combined.csv\"\n", "casmi16train_path = f\"{home}/data/metabolites/CASMI_2016/casmi16_training_combined.csv\"\n", "casmi22_path = f\"{home}/data/metabolites/CASMI_2022/casmi22_challenges_combined.csv\"\n", @@ -289,7 +331,7 @@ "\n", "# for i,d in df_cas22.iterrows():\n", "# m = d[\"Metabolite\"]\n", - " \n", + "\n", "# for x,D in df.iterrows():\n", "# M = D[\"Metabolite\"]\n", "# if (m == M):\n", @@ -299,13 +341,12 @@ "# ces = (d[\"CE\"], D[\"CE\"])\n", "# ddd.append({\"cas16\": i, \"df\": x, \"cas_CE\": d[\"CE\"], \"df_ce\": D[\"CE\"], \"CE_dif\": abs(d[\"CE\"] - D[\"CE\"]), \"cosine\": cos})\n", "# print(f\"Match {i}, {x}, CE: {ces} COSINE: {cos:.3f}\")\n", - " \n", "\n", - "# ii = np.unique(iii) \n", + "\n", + "# ii = np.unique(iii)\n", "# print(f\"Found {len(iii)} instances violating test/train split. 
Metabolite found in train/val set.\")\n", "# print(f\"Dropping {len(xxx)} spectra from training DataFrame.\")\n", "# # df.drop(xxx, inplace=True)\n", - "\n", "\n" ] }, @@ -376,26 +417,30 @@ "\n", "iii = []\n", "xxx = []\n", - "for i,d in df_cas.iterrows():\n", + "for i, d in df_cas.iterrows():\n", " m = d[\"Metabolite\"]\n", - " \n", - " for x,D in df.iterrows():\n", + "\n", + " for x, D in df.iterrows():\n", " M = D[\"Metabolite\"]\n", - " if (m == M):\n", + " if m == M:\n", " iii += [i]\n", " xxx += [x]\n", " if D[\"CCS\"]:\n", - " df_cas.at[i, \"CCS\"] = df_cas.at[i, \"CCS\"] + [D[\"CCS\"]] # Add CCS metadata\n", + " df_cas.at[i, \"CCS\"] = df_cas.at[i, \"CCS\"] + [\n", + " D[\"CCS\"]\n", + " ] # Add CCS metadata\n", + "\n", "\n", - " \n", - "iii = np.unique(iii) \n", - "print(f\"Found {len(iii)} instances violating test/train split (CASMI 16 Challenge). Metabolite found in train/val set.\")\n", + "iii = np.unique(iii)\n", + "print(\n", + " f\"Found {len(iii)} instances violating test/train split (CASMI 16 Challenge). Metabolite found in train/val set.\"\n", + ")\n", "print(f\"Dropping {len(xxx)} spectra from training DataFrame.\")\n", "df.drop(xxx, inplace=True)\n", "\n", "# Add CCS metadata\n", "df_cas[\"CCS_std\"] = df_cas[\"CCS\"].apply(np.std)\n", - "df_cas[\"CCS\"] = df_cas[\"CCS\"].apply(np.mean)\n" + "df_cas[\"CCS\"] = df_cas[\"CCS\"].apply(np.mean)" ] }, { @@ -435,20 +480,24 @@ "\n", "iii = []\n", "xxx = []\n", - "for i,d in df_cast.iterrows():\n", + "for i, d in df_cast.iterrows():\n", " m = d[\"Metabolite\"]\n", - " \n", - " for x,D in df.iterrows():\n", + "\n", + " for x, D in df.iterrows():\n", " M = D[\"Metabolite\"]\n", - " if (m == M):\n", + " if m == M:\n", " iii += [i]\n", " xxx += [x]\n", " if D[\"CCS\"]:\n", - " df_cast.at[i, \"CCS\"] = df_cast.at[i, \"CCS\"] + [D[\"CCS\"]] # Add CCS metadata\n", + " df_cast.at[i, \"CCS\"] = df_cast.at[i, \"CCS\"] + [\n", + " D[\"CCS\"]\n", + " ] # Add CCS metadata\n", + "\n", "\n", - " \n", - "iii = np.unique(iii) \n", - "print(f\"Found {len(iii)} instances violating test/train split (CASMI 16 Training). Metabolite found in train/val set.\")\n", + "iii = np.unique(iii)\n", + "print(\n", + " f\"Found {len(iii)} instances violating test/train split (CASMI 16 Training). Metabolite found in train/val set.\"\n", + ")\n", "print(f\"Dropping {len(xxx)} spectra from training DataFrame.\")\n", "df.drop(xxx, inplace=True)\n", "\n", @@ -486,7 +535,7 @@ } ], "source": [ - "df_cas22_unique = df_cas22.drop_duplicates(subset='ChallengeName', keep='first')\n", + "df_cas22_unique = df_cas22.drop_duplicates(subset=\"ChallengeName\", keep=\"first\")\n", "df_cas22_unique.reset_index(inplace=True)\n", "df_cas22_unique[\"Metabolite\"] = df_cas22_unique[\"SMILES\"].apply(Metabolite)\n", "df_cas22_unique.shape" @@ -555,28 +604,36 @@ "df_cas22_unique[\"CCS\"] = [[]] * df_cas22_unique.shape[0]\n", "\n", "\n", - "for i,d in df_cas22_unique.iterrows():\n", + "for i, d in df_cas22_unique.iterrows():\n", " m = d[\"Metabolite\"]\n", - " \n", - " for x,D in df.iterrows():\n", + "\n", + " for x, D in df.iterrows():\n", " M = D[\"Metabolite\"]\n", - " if (m == M):\n", + " if m == M:\n", " iii += [i]\n", " xxx += [x]\n", - " \n", + "\n", " if D[\"CCS\"]:\n", - " df_cas22_unique.at[i, \"CCS\"] = df_cas22_unique.at[i, \"CCS\"] + [D[\"CCS\"]] # Add CCS metadata\n", - " \n", - "iii = np.unique(iii) \n", - "print(f\"Found {len(iii)} instances violating test/train split (CASMI 22). 
Metabolite found in train/val set.\")\n", + " df_cas22_unique.at[i, \"CCS\"] = df_cas22_unique.at[i, \"CCS\"] + [\n", + " D[\"CCS\"]\n", + " ] # Add CCS metadata\n", + "\n", + "iii = np.unique(iii)\n", + "print(\n", + " f\"Found {len(iii)} instances violating test/train split (CASMI 22). Metabolite found in train/val set.\"\n", + ")\n", "print(f\"Dropping {len(xxx)} spectra from training DataFrame.\")\n", "df.drop(xxx, inplace=True)\n", "\n", "# Add CCS metadata\n", "df_cas22_unique[\"CCS_std\"] = df_cas22_unique[\"CCS\"].apply(np.std)\n", "df_cas22_unique[\"CCS\"] = df_cas22_unique[\"CCS\"].apply(np.mean)\n", - "df_cas22[\"CCS\"] = df_cas22[\"ChallengeName\"].apply(lambda x: df_cas22_unique[df_cas22_unique[\"ChallengeName\"] == x][\"CCS\"].iloc[0])\n", - "df_cas22[\"CCS_std\"] = df_cas22[\"ChallengeName\"].apply(lambda x: df_cas22_unique[df_cas22_unique[\"ChallengeName\"] == x][\"CCS_std\"].iloc[0])\n" + "df_cas22[\"CCS\"] = df_cas22[\"ChallengeName\"].apply(\n", + " lambda x: df_cas22_unique[df_cas22_unique[\"ChallengeName\"] == x][\"CCS\"].iloc[0]\n", + ")\n", + "df_cas22[\"CCS_std\"] = df_cas22[\"ChallengeName\"].apply(\n", + " lambda x: df_cas22_unique[df_cas22_unique[\"ChallengeName\"] == x][\"CCS_std\"].iloc[0]\n", + ")" ] }, { @@ -589,8 +646,12 @@ "\n", "# TODO add those to df_cas22\n", "\n", - "df_cas22[\"CCS\"] = df_cas22[\"ChallengeName\"].apply(lambda x: df_cas22_unique[df_cas22_unique[\"ChallengeName\"] == x][\"CCS\"].iloc[0])\n", - "df_cas22[\"CCS_std\"] = df_cas22[\"ChallengeName\"].apply(lambda x: df_cas22_unique[df_cas22_unique[\"ChallengeName\"] == x][\"CCS_std\"].iloc[0])" + "df_cas22[\"CCS\"] = df_cas22[\"ChallengeName\"].apply(\n", + " lambda x: df_cas22_unique[df_cas22_unique[\"ChallengeName\"] == x][\"CCS\"].iloc[0]\n", + ")\n", + "df_cas22[\"CCS_std\"] = df_cas22[\"ChallengeName\"].apply(\n", + " lambda x: df_cas22_unique[df_cas22_unique[\"ChallengeName\"] == x][\"CCS_std\"].iloc[0]\n", + ")" ] }, { @@ -600,15 +661,14 @@ "outputs": [], "source": [ "# Save casmi with\n", - "save_df=False\n", + "save_df = False\n", "if save_df:\n", - "\n", " df_cas.to_csv(f\"{home}/data/metabolites/CASMI_2016/casmi16_withCCS.csv\")\n", " df_cast.to_csv(f\"{home}/data/metabolites/CASMI_2016/casmi16t_withCCS.csv\")\n", " df_cas22.to_csv(f\"{home}/data/metabolites/CASMI_2022/casmi22_withCCS.csv\")\n", "\n", " print(df_cas.head(3))\n", - " print(df_cas22.head(3))\n" + " print(df_cas22.head(3))" ] }, { @@ -670,7 +730,7 @@ } ], "source": [ - "df.groupby(\"lib\").group_id.unique().apply(len)\n" + "df.groupby(\"lib\").group_id.unique().apply(len)" ] }, { @@ -708,8 +768,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "#df.groupby('group_id')" + "# df.groupby('group_id')" ] }, { @@ -758,10 +817,18 @@ ], "source": [ "from fiora.MOL.mol_graph import draw_graph\n", - "from fiora.visualization.define_colors import define_figure_style, color_palette, bluepink, bluepink_grad, bluepink_grad8, tri_palette\n", + "from fiora.visualization.define_colors import (\n", + " define_figure_style,\n", + " color_palette,\n", + " bluepink,\n", + " bluepink_grad,\n", + " bluepink_grad8,\n", + " tri_palette,\n", + ")\n", "from fiora.visualization.define_colors import *\n", "import matplotlib.pyplot as plt\n", - "matplotlib.rcParams['figure.figsize'] = (12, 6)\n", + "\n", + "matplotlib.rcParams[\"figure.figsize\"] = (12, 6)\n", "\n", "magma_palette = define_figure_style(style=\"magma-white\", palette_steps=8)\n", "\n", @@ -769,7 +836,9 @@ " EXAMPLE_ID = 32271 if (lib == \"NIST\") else 7607 if lib == 
\"MSDIAL\" else 0\n", " example = df.loc[EXAMPLE_ID]\n", "\n", - " fig, axs = plt.subplots(1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 1]}, sharey=False)\n", + " fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 1]}, sharey=False\n", + " )\n", " set_light_theme()\n", "\n", " img = example[\"Metabolite\"].draw(ax=axs[0])\n", @@ -794,12 +863,16 @@ ], "source": [ "if not test_run:\n", - " fig, axs = plt.subplots(1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", + " fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " )\n", "\n", - " img = example[\"Metabolite\"].draw(ax= axs[0])\n", + " img = example[\"Metabolite\"].draw(ax=axs[0])\n", "\n", " axs[0].grid(False)\n", - " axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False)\n", + " axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " )\n", " axs[0].set_title(str(example[\"Metabolite\"]))\n", " axs[0].imshow(img)\n", " axs[0].axis(\"off\")\n", @@ -814,7 +887,12 @@ "source": [ "%%capture\n", "df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "df.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]), axis=1)\n" + "df.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -850,7 +928,7 @@ "source": [ "example = df.iloc[10]\n", "example[\"Metabolite\"].match_stats\n", - "#raise KeyboardInterrupt()" + "# raise KeyboardInterrupt()" ] }, { @@ -872,7 +950,9 @@ "source": [ "df[\"np\"] = df[\"Metabolite\"].apply(lambda x: x.match_stats[\"num_peak_matches\"])\n", "df[\"npf\"] = df[\"Metabolite\"].apply(lambda x: x.match_stats[\"num_peak_matches_filtered\"])\n", - "df[\"ppf\"] = df[\"Metabolite\"].apply(lambda x: x.match_stats[\"percent_peak_matches_filtered\"])\n", + "df[\"ppf\"] = df[\"Metabolite\"].apply(\n", + " lambda x: x.match_stats[\"percent_peak_matches_filtered\"]\n", + ")\n", "\n", "sns.violinplot(data=df, y=\"ppf\")\n", "plt.show()" @@ -926,7 +1006,7 @@ } ], "source": [ - "df[\"np\"].value_counts().head(10)\n" + "df[\"np\"].value_counts().head(10)" ] }, { @@ -986,7 +1066,7 @@ "metadata": {}, "outputs": [], "source": [ - "#df = ORI_DF.copy(deep=True)\n", + "# df = ORI_DF.copy(deep=True)\n", "ORI_DF = df.copy(deep=True)" ] }, @@ -1015,11 +1095,17 @@ "source": [ "from fiora.MOL.constants import ADDUCT_WEIGHTS, PPM\n", "\n", - "df[\"Precursor_offset\"] = df[\"PrecursorMZ\"] - df.apply(lambda x: x[\"Metabolite\"].ExactMolWeight + ADDUCT_WEIGHTS[x[\"Precursor_type\"]], axis=1)\n", + "df[\"Precursor_offset\"] = df[\"PrecursorMZ\"] - df.apply(\n", + " lambda x: x[\"Metabolite\"].ExactMolWeight + ADDUCT_WEIGHTS[x[\"Precursor_type\"]],\n", + " axis=1,\n", + ")\n", "df[\"Precursor_abs_error\"] = abs(df[\"Precursor_offset\"])\n", "df[\"Precursor_rel_error\"] = df[\"Precursor_abs_error\"] / df[\"PrecursorMZ\"]\n", "df[\"Precursor_ppm_error\"] = df[\"Precursor_abs_error\"] / (df[\"PrecursorMZ\"] * PPM)\n", - "print((df[\"Precursor_ppm_error\"] > df[\"PPM_num\"]).sum(), \"found with misaligned precursor. Removing these.\")\n", + "print(\n", + " (df[\"Precursor_ppm_error\"] > df[\"PPM_num\"]).sum(),\n", + " \"found with misaligned precursor. 
Removing these.\",\n", + ")\n", "\n", "df = df[df[\"Precursor_ppm_error\"] <= df[\"PPM_num\"]]" ] @@ -1052,7 +1138,7 @@ ], "source": [ "df[\"coverage\"] = df[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", - "sns.kdeplot(data=df, x=\"coverage\", hue=\"lib\")#, multiple=\"dodge\")" + "sns.kdeplot(data=df, x=\"coverage\", hue=\"lib\") # , multiple=\"dodge\")" ] }, { @@ -1081,29 +1167,32 @@ "source": [ "# TODO Implement conflict solver\n", "\n", - "coverage_tracker = {\"counts\": [], \"all\": [], \"coverage\": [], \"fragment_only_coverage\": [], \"Precursor_type\": []}\n", + "coverage_tracker = {\n", + " \"counts\": [],\n", + " \"all\": [],\n", + " \"coverage\": [],\n", + " \"fragment_only_coverage\": [],\n", + " \"Precursor_type\": [],\n", + "}\n", "\n", "drop_index = []\n", - "for i,d in df.iterrows():\n", + "for i, d in df.iterrows():\n", " M = d[\"Metabolite\"]\n", - " \n", - " \n", + "\n", " coverage_tracker[\"counts\"] += [M.match_stats[\"counts\"]]\n", " coverage_tracker[\"all\"] += [M.match_stats[\"ms_all_counts\"]]\n", " coverage_tracker[\"fragment_only_coverage\"] += [M.match_stats[\"coverage_wo_prec\"]]\n", " coverage_tracker[\"coverage\"] += [M.match_stats[\"coverage\"]]\n", " coverage_tracker[\"Precursor_type\"] += [M.metadata[\"precursor_mode\"]]\n", - " \n", - " \n", + "\n", " #\n", " # IMPORTANT: FILTER AND CLEAN DATA\n", " #\n", - " \n", - " \n", + "\n", " min_coverage = 0.5\n", - " if M.match_stats[\"coverage\"] < min_coverage: # Filter if total coverage is too low\n", + " if M.match_stats[\"coverage\"] < min_coverage: # Filter if total coverage is too low\n", " drop_index.append(i)\n", - " \n", + "\n", " min_peaks = 2\n", " if M.match_stats[\"num_peak_matches_filtered\"] < min_peaks:\n", " drop_index.append(i)\n", @@ -1112,42 +1201,63 @@ " desired_peaks = 5\n", " desired_peak_percentage = 0.5\n", " extremly_high_coverage = 0.8\n", - " if (M.match_stats[\"num_peak_matches_filtered\"] < desired_peaks) & (M.match_stats[\"percent_peak_matches_filtered\"] < desired_peak_percentage) & (M.match_stats[\"coverage\"] < extremly_high_coverage):\n", + " if (\n", + " (M.match_stats[\"num_peak_matches_filtered\"] < desired_peaks)\n", + " & (M.match_stats[\"percent_peak_matches_filtered\"] < desired_peak_percentage)\n", + " & (M.match_stats[\"coverage\"] < extremly_high_coverage)\n", + " ):\n", " drop_index.append(i)\n", "\n", - " #if M.match_stats[\"num_non_precursor_matches\"] < 1:\n", + " # if M.match_stats[\"num_non_precursor_matches\"] < 1:\n", " # drop_index.append(i)\n", " # max_conflicts_rel = 0.75\n", - " # if M.match_stats[\"rel_fragment_conflicts\"] > max_conflicts_rel: \n", + " # if M.match_stats[\"rel_fragment_conflicts\"] > max_conflicts_rel:\n", " # drop_index.append(i)\n", - " \n", + "\n", " max_precursor = 0.9\n", - " if M.match_stats[\"precursor_prob\"] > max_precursor: # Filter if fragment coverage is too low (intensity wise)\n", - " drop_index.append(i)\n", - " \n", + " if (\n", + " M.match_stats[\"precursor_prob\"] > max_precursor\n", + " ): # Filter if fragment coverage is too low (intensity wise)\n", + " drop_index.append(i)\n", + "\n", " # if d[\"lib\"] == \"MSDIAL\":\n", " # if d[\"INSTRUMENTTYPE\"] == \"Orbitrap\": # and M.match_stats[\"precursor_prob\"] > 0.95:\n", " # drop_index.append(i)\n", - " \n", - " \n", + "\n", + "\n", "# filter low res instruments TODO update to low quality spectra\n", - "low_quality_tags = [\"IT/ion trap\", \"QqQ\", \"LC-ESI-QQ\", \"Flow-injection QqQ/MS\", \"LC-APPI-QQ\", \"LC-ESI-IT\", \"LC-ESI-QIT\", \"QIT\"] #What 
about ESI-ITTOF? GC-APCI-QTOF?\n", - "low_res_machines = df[\"Metabolite\"].apply(lambda x: x.metadata[\"instrument\"] in low_quality_tags)\n", + "low_quality_tags = [\n", + " \"IT/ion trap\",\n", + " \"QqQ\",\n", + " \"LC-ESI-QQ\",\n", + " \"Flow-injection QqQ/MS\",\n", + " \"LC-APPI-QQ\",\n", + " \"LC-ESI-IT\",\n", + " \"LC-ESI-QIT\",\n", + " \"QIT\",\n", + "] # What about ESI-ITTOF? GC-APCI-QTOF?\n", + "low_res_machines = df[\"Metabolite\"].apply(\n", + " lambda x: x.metadata[\"instrument\"] in low_quality_tags\n", + ")\n", "drop_index += list(df[low_res_machines].index)\n", "\n", "fig, axs = plt.subplots(1, 3, figsize=(12.8, 4.2), sharey=True)\n", "\n", - "plt.ylim([-.02,1.02])\n", - "sns.boxplot(ax=axs[0], data=coverage_tracker, y=\"fragment_only_coverage\", color=magma_palette[1])\n", + "plt.ylim([-0.02, 1.02])\n", + "sns.boxplot(\n", + " ax=axs[0], data=coverage_tracker, y=\"fragment_only_coverage\", color=magma_palette[1]\n", + ")\n", "sns.boxplot(ax=axs[1], data=coverage_tracker, y=\"coverage\", color=magma_palette[2])\n", "sns.violinplot(ax=axs[2], data=coverage_tracker, y=\"coverage\", color=magma_palette[2])\n", "axs[0].set_title(\"Coverage of peak intensity (fragments only)\")\n", "axs[1].set_title(\"Coverage of peak intensity\")\n", "axs[2].set_title(\"Coverage of peak intensity\")\n", - "axs[2].axhline(y=min_coverage, color='black', linestyle='--', label='Horizontal Line')\n", + "axs[2].axhline(y=min_coverage, color=\"black\", linestyle=\"--\", label=\"Horizontal Line\")\n", "plt.show()\n", "\n", - "print(f\"Filtering out {len(drop_index)} that have only precursor matches || or || too little (intensity) coverage to make edge prediction possible\")\n", + "print(\n", + " f\"Filtering out {len(drop_index)} that have only precursor matches || or || too little (intensity) coverage to make edge prediction possible\"\n", + ")\n", "df.drop(drop_index, inplace=True)" ] }, @@ -1187,22 +1297,23 @@ "source": [ "def custom_sample(group):\n", " sample_size = 20\n", - " x_origin = group[group['lib'] == 'NIST']\n", - " y_origin = group[group['lib'] == 'MSDIAL']\n", - " \n", + " x_origin = group[group[\"lib\"] == \"NIST\"]\n", + " y_origin = group[group[\"lib\"] == \"MSDIAL\"]\n", + "\n", " if len(x_origin) >= sample_size:\n", " return x_origin.sample(n=sample_size, replace=False)\n", " elif len(y_origin) > 0:\n", " y_sample_size = min(len(y_origin), sample_size - len(x_origin))\n", " return pd.concat([x_origin, y_origin.sample(n=y_sample_size, replace=False)])\n", " else:\n", - " return(x_origin)\n", + " return x_origin\n", + "\n", "\n", - "#TURNED OFF\n", - "#df = df.groupby('group_id', group_keys=False).apply(custom_sample)\n", + "# TURNED OFF\n", + "# df = df.groupby('group_id', group_keys=False).apply(custom_sample)\n", "## df = df.groupby('group_id').apply(lambda group: group.sample(n=10, replace=True) if len(group) >= 10 else group)\n", "\n", - "#df = df.reset_index(drop=True)\n", + "# df = df.reset_index(drop=True)\n", "\n", "num_structures = len(df[\"group_id\"].unique())\n", "print(f\"Train and validate on {num_structures} unique structures\")\n", @@ -1264,11 +1375,11 @@ "from fiora.GNN.Trainer import Trainer\n", "import torch_geometric as geom\n", "\n", - "if torch.cuda.is_available(): \n", - " dev = \"cuda:0\"\n", - "else: \n", - " dev = \"cpu\" \n", - " \n", + "if torch.cuda.is_available():\n", + " dev = \"cuda:0\"\n", + "else:\n", + " dev = \"cpu\"\n", + "\n", "print(f\"Running on device: {dev}\")" ] }, @@ -1286,9 +1397,8 @@ "metadata": {}, "outputs": [], "source": [ - 
"#df.shape\n", - "DF_BACKUP = df.copy(deep=True)\n", - "\n" + "# df.shape\n", + "DF_BACKUP = df.copy(deep=True)" ] }, { @@ -1319,9 +1429,14 @@ ], "source": [ "df[\"pp\"] = df[\"Metabolite\"].apply(lambda x: x.match_stats[\"precursor_prob\"])\n", - "top_instrumenttypes = df['Instrument_type'].value_counts().head(3).index\n", - "f,a = plt.subplots(1,1, figsize=(8, 4))\n", - "sns.histplot(data=df[df['Instrument_type'].isin(top_instrumenttypes)], x=\"pp\", hue=\"Instrument_type\", multiple=\"dodge\")\n", + "top_instrumenttypes = df[\"Instrument_type\"].value_counts().head(3).index\n", + "f, a = plt.subplots(1, 1, figsize=(8, 4))\n", + "sns.histplot(\n", + " data=df[df[\"Instrument_type\"].isin(top_instrumenttypes)],\n", + " x=\"pp\",\n", + " hue=\"Instrument_type\",\n", + " multiple=\"dodge\",\n", + ")\n", "plt.show()" ] }, @@ -1395,27 +1510,27 @@ "outputs": [], "source": [ "model_params = {\n", - " 'gnn_type': 'RGCNConv',\n", - " 'depth': 3,\n", - " 'hidden_dimension': 300,\n", - " 'dense_layers': 2,\n", - " 'embedding_aggregation': 'concat',\n", - " 'embedding_dimension': 300,\n", - " 'input_dropout': 0.2,\n", - " 'latent_dropout': 0.1,\n", - " 'node_feature_layout': node_encoder.feature_numbers,\n", - " 'edge_feature_layout': bond_encoder.feature_numbers, \n", - " 'static_feature_dimension': geo_data[0][\"static_edge_features\"].shape[1],\n", - " 'static_rt_feature_dimension': geo_data[0][\"static_rt_features\"].shape[1],\n", - " 'output_dimension': len(DEFAULT_MODES) * 2, # per edge \n", + " \"gnn_type\": \"RGCNConv\",\n", + " \"depth\": 3,\n", + " \"hidden_dimension\": 300,\n", + " \"dense_layers\": 2,\n", + " \"embedding_aggregation\": \"concat\",\n", + " \"embedding_dimension\": 300,\n", + " \"input_dropout\": 0.2,\n", + " \"latent_dropout\": 0.1,\n", + " \"node_feature_layout\": node_encoder.feature_numbers,\n", + " \"edge_feature_layout\": bond_encoder.feature_numbers,\n", + " \"static_feature_dimension\": geo_data[0][\"static_edge_features\"].shape[1],\n", + " \"static_rt_feature_dimension\": geo_data[0][\"static_rt_features\"].shape[1],\n", + " \"output_dimension\": len(DEFAULT_MODES) * 2, # per edge\n", "}\n", "training_params = {\n", - " 'epochs': 200 if not test_run else 20, #180,\n", - " 'batch_size': 256, #128,\n", - " 'train_val_split': 0.90,\n", - " 'learning_rate': 0.0004,#0.001,\n", - " 'with_RT': False, # TODO CHANGED\n", - " 'with_CCS': False\n", + " \"epochs\": 200 if not test_run else 20, # 180,\n", + " \"batch_size\": 256, # 128,\n", + " \"train_val_split\": 0.90,\n", + " \"learning_rate\": 0.0004, # 0.001,\n", + " \"with_RT\": False, # TODO CHANGED\n", + " \"with_CCS\": False,\n", "}" ] }, @@ -1427,7 +1542,8 @@ "source": [ "from fiora.GNN.GNNModules import GNNCompiler\n", "from fiora.GNN.Losses import WeightedMSELoss, WeightedMSEMetric\n", - "model = GNNCompiler(model_params).to(dev)\n" + "\n", + "model = GNNCompiler(model_params).to(dev)" ] }, { @@ -1458,13 +1574,19 @@ "source": [ "from sklearn.model_selection import train_test_split\n", "\n", + "\n", "def train_val_test_split(keys, test_size=0.1, val_size=0.1, rseed=seed):\n", - " temp_keys, test_keys = train_test_split(keys, test_size=test_size, random_state=rseed)\n", + " temp_keys, test_keys = train_test_split(\n", + " keys, test_size=test_size, random_state=rseed\n", + " )\n", " adjusted_val_size = val_size / (1 - test_size)\n", - " train_keys, val_keys = train_test_split(temp_keys, test_size=adjusted_val_size, random_state=rseed)\n", - " \n", + " train_keys, val_keys = train_test_split(\n", + " temp_keys, 
test_size=adjusted_val_size, random_state=rseed\n", + " )\n", + "\n", " return train_keys, val_keys, test_keys\n", "\n", + "\n", "# Make sure that the example is in the test split\n", "if not test_run:\n", " ex_smiles = \"CC(NC(=O)CC1=CNC2=C1C=CC=C2)C(O)=O\"\n", @@ -1478,9 +1600,21 @@ "for i in range(100):\n", " train, val, test = train_val_test_split(keys, rseed=seed + i)\n", " if test_run or (ex_compound_id in test):\n", - " print(f\"Seed {seed + i} used to sample slits, such that the example Metabolite is in the test set.\")\n", + " print(\n", + " f\"Seed {seed + i} used to sample splits, such that the example Metabolite is in the test set.\"\n", + " )\n", " break\n", - "df[\"dataset\"] = df[\"group_id\"].apply(lambda x: 'train' if x in train else 'validation' if x in val else 'test' if x in test else 'VALUE ERROR')\n" + "df[\"dataset\"] = df[\"group_id\"].apply(\n", + " lambda x: (\n", + " \"train\"\n", + " if x in train\n", + " else \"validation\"\n", + " if x in val\n", + " else \"test\"\n", + " if x in test\n", + " else \"VALUE ERROR\"\n", + " )\n", + ")" ] }, { @@ -1497,16 +1631,31 @@ } ], "source": [ - "y_label = 'compiled_probsALL'\n", - "train_keys, val_keys = df[df[\"dataset\"] == \"train\"][\"group_id\"].unique(), df[df[\"dataset\"] == \"validation\"][\"group_id\"].unique()\n", - "\n", - "trainer = Trainer(geo_data, y_tag=y_label, problem_type=\"regression\", metric_dict={\"mse\": WeightedMSEMetric}, train_keys=train_keys, val_keys=val_keys, split_by_group=True, seed=seed, device=dev)\n", + "y_label = \"compiled_probsALL\"\n", + "train_keys, val_keys = (\n", + " df[df[\"dataset\"] == \"train\"][\"group_id\"].unique(),\n", + " df[df[\"dataset\"] == \"validation\"][\"group_id\"].unique(),\n", + ")\n", + "\n", + "trainer = Trainer(\n", + " geo_data,\n", + " y_tag=y_label,\n", + " problem_type=\"regression\",\n", + " metric_dict={\"mse\": WeightedMSEMetric},\n", + " train_keys=train_keys,\n", + " val_keys=val_keys,\n", + " split_by_group=True,\n", + " seed=seed,\n", + " device=dev,\n", + ")\n", "optimizer = torch.optim.Adam(model.parameters(), lr=training_params[\"learning_rate\"])\n", "#scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98) \n", "scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 8, factor=0.5, mode = 'min', verbose = True)\n", + "# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)\n", + "scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", + " optimizer, patience=8, factor=0.5, mode=\"min\", verbose=True\n", + ")\n", "\n", "loss_fn = WeightedMSELoss()\n", - "#loss_fn = torch.nn.MSELoss()" + "# loss_fn = torch.nn.MSELoss()" ] }, { @@ -1736,7 +1885,19 @@ ], "source": [ "tag = \"b1\"\n", - "checkpoints = trainer.train(model, optimizer, loss_fn, scheduler=scheduler, batch_size=training_params['batch_size'], epochs=training_params[\"epochs\"], val_every_n_epochs=1, with_CCS=training_params[\"with_CCS\"], with_RT=training_params[\"with_RT\"], masked_validation=False, tag=tag) #, mask_name=\"compiled_validation_maskALL\")" + "checkpoints = trainer.train(\n", + " model,\n", + " optimizer,\n", + " loss_fn,\n", + " scheduler=scheduler,\n", + " batch_size=training_params[\"batch_size\"],\n", + " epochs=training_params[\"epochs\"],\n", + " val_every_n_epochs=1,\n", + " with_CCS=training_params[\"with_CCS\"],\n", + " with_RT=training_params[\"with_RT\"],\n", + " masked_validation=False,\n", + " tag=tag,\n", + ") # , mask_name=\"compiled_validation_maskALL\")" ] }, { @@ -1767,8 +1928,7 @@ " 
print(f\"Loading model from checkpoint {checkpoints}.\")\n", " end_model = GNNCompiler(model_params).to(dev)\n", " end_model = end_model.load_state_dict(model.state_dict())\n", - " model = model.load(checkpoints[\"file\"]).to(dev)\n", - " " + " model = model.load(checkpoints[\"file\"]).to(dev)" ] }, { @@ -1796,10 +1956,23 @@ "source": [ "from fiora.MS.SimulationFramework import SimulationFramework\n", "\n", - "df[\"dataset\"] = df[\"group_id\"].apply(lambda x: \"training\" if trainer.is_group_in_training_set(x) else \"validation\" if trainer.is_group_in_validation_set(x) else \"test\") \n", + "df[\"dataset\"] = df[\"group_id\"].apply(\n", + " lambda x: (\n", + " \"training\"\n", + " if trainer.is_group_in_training_set(x)\n", + " else \"validation\"\n", + " if trainer.is_group_in_validation_set(x)\n", + " else \"test\"\n", + " )\n", + ")\n", "\n", - "fiora = SimulationFramework(model, dev=dev, with_RT=training_params[\"with_RT\"], with_CCS=training_params[\"with_CCS\"])\n", - "df = fiora.simulate_all(df, model)\n" + "fiora = SimulationFramework(\n", + " model,\n", + " dev=dev,\n", + " with_RT=training_params[\"with_RT\"],\n", + " with_CCS=training_params[\"with_CCS\"],\n", + ")\n", + "df = fiora.simulate_all(df, model)" ] }, { @@ -1848,17 +2021,23 @@ " cosine = data[\"spectral_sqrt_cosine\"]\n", " name = data[\"Name\"]\n", " print(f\"{name} ({i}): cosine {cosine:0.2}\")\n", - " fig, axs = plt.subplots(1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", - " img = data[\"Metabolite\"].draw(ax= axs[0])\n", - "\n", - " #axs[0].grid(False)\n", - " axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False) \n", + " fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " )\n", + " img = data[\"Metabolite\"].draw(ax=axs[0])\n", + "\n", + " # axs[0].grid(False)\n", + " axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " )\n", " axs[0].set_title(data[\"Name\"])\n", - " #axs[0].imshow(img)\n", - " #axs[0].axis(\"off\")\n", - " #sv.plot_spectrum(example, ax=axs[1])\n", - " ax = sv.plot_spectrum(data, {\"peaks\": data[\"sim_peaks\"]}, ax=axs[1], highlight_matches=False)\n", - " plt.show()\n" + " # axs[0].imshow(img)\n", + " # axs[0].axis(\"off\")\n", + " # sv.plot_spectrum(example, ax=axs[1])\n", + " ax = sv.plot_spectrum(\n", + " data, {\"peaks\": data[\"sim_peaks\"]}, ax=axs[1], highlight_matches=False\n", + " )\n", + " plt.show()" ] }, { @@ -1970,17 +2149,43 @@ } ], "source": [ - "df[[\"cosine_similarity\", \"kl_div\", \"spectral_cosine\", \"spectral_sqrt_cosine\", \"spectral_refl_cosine\", \"spectral_bias\", \"spectral_sqrt_bias\", \"spectral_refl_bias\"]] = df[[\"cosine_similarity\", \"kl_div\", \"spectral_cosine\", \"spectral_sqrt_cosine\", \"spectral_refl_cosine\", \"spectral_bias\", \"spectral_sqrt_bias\", \"spectral_refl_bias\"]].astype(float)\n", - "#df[\"is_peptide\"] = df[\"Notes\"].apply(lambda x: \"Peptide\" in x) if lib == \"NIST\" else False\n", + "df[\n", + " [\n", + " \"cosine_similarity\",\n", + " \"kl_div\",\n", + " \"spectral_cosine\",\n", + " \"spectral_sqrt_cosine\",\n", + " \"spectral_refl_cosine\",\n", + " \"spectral_bias\",\n", + " \"spectral_sqrt_bias\",\n", + " \"spectral_refl_bias\",\n", + " ]\n", + "] = df[\n", + " [\n", + " \"cosine_similarity\",\n", + " \"kl_div\",\n", + " \"spectral_cosine\",\n", + " \"spectral_sqrt_cosine\",\n", + " \"spectral_refl_cosine\",\n", + " 
\"spectral_bias\",\n", + " \"spectral_sqrt_bias\",\n", + " \"spectral_refl_bias\",\n", + " ]\n", + "].astype(float)\n", + "# df[\"is_peptide\"] = df[\"Notes\"].apply(lambda x: \"Peptide\" in x) if lib == \"NIST\" else False\n", "\n", "df_train = df[df[\"dataset\"] == \"training\"]\n", - "df_val = df[df[\"dataset\"]==\"validation\"]\n", + "df_val = df[df[\"dataset\"] == \"validation\"]\n", "df_val[\"library\"] = \"Validation\"\n", "\n", "for key in example[\"Metabolite\"].match_stats.keys():\n", " df_val[key] = df[\"Metabolite\"].apply(lambda x: x.match_stats[key])\n", "\n", - "df_val[\"ring_proportion\"] = df[\"Metabolite\"].apply(lambda x: (getattr(x, \"is_edge_in_ring\").sum() / getattr(x, \"is_edge_in_ring\").shape[0]).tolist())\n" + "df_val[\"ring_proportion\"] = df[\"Metabolite\"].apply(\n", + " lambda x: (\n", + " getattr(x, \"is_edge_in_ring\").sum() / getattr(x, \"is_edge_in_ring\").shape[0]\n", + " ).tolist()\n", + ")" ] }, { @@ -2130,7 +2335,7 @@ "# img = Phacidin.draw(ax= axs[0])\n", "\n", "# axs[0].grid(False)\n", - "# axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False) \n", + "# axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False)\n", "# axs[0].set_title(\"Phacidin\")\n", "# axs[0].imshow(img)\n", "# axs[0].axis(\"off\")\n", @@ -2195,12 +2400,12 @@ } ], "source": [ - "#from fiora.visualization.define_colors import bluepink_grad8\n", + "# from fiora.visualization.define_colors import bluepink_grad8\n", "\n", "sns.palplot(bluepink)\n", "plt.show()\n", "sns.palplot(bluepink_grad8)\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -2223,12 +2428,20 @@ ], "source": [ "print(\"Stats at first glance (training and validation)\")\n", - "keys = [\"cosine_similarity\", \"kl_div\", \"spectral_cosine\", \"spectral_sqrt_cosine\", \"spectral_refl_cosine\"]\n", + "keys = [\n", + " \"cosine_similarity\",\n", + " \"kl_div\",\n", + " \"spectral_cosine\",\n", + " \"spectral_sqrt_cosine\",\n", + " \"spectral_refl_cosine\",\n", + "]\n", "\n", "for key in keys:\n", " blue = PRINT_COL[\"blue\"]\n", " end = PRINT_COL[\"end\"]\n", - " print(f\"Median {key}: \\t{df_train[key].median():.2f} {blue} {df_val[key].median():.2f} {end}\")" + " print(\n", + " f\"Median {key}: \\t{df_train[key].median():.2f} {blue} {df_val[key].median():.2f} {end}\"\n", + " )" ] }, { @@ -2265,9 +2478,15 @@ "source": [ "fig, axs = plt.subplots(1, 4, figsize=(18, 6), sharey=False)\n", "\n", - "sns.boxplot(ax=axs[0], data=df, y=\"spectral_cosine\", hue=\"dataset\", palette=bluepink[:3])\n", - "sns.boxplot(ax=axs[1], data=df, y=\"spectral_sqrt_cosine\", hue=\"dataset\", palette=bluepink[:3])\n", - "sns.boxplot(ax=axs[2], data=df, y=\"spectral_refl_cosine\", hue=\"dataset\", palette=bluepink[:3])\n", + "sns.boxplot(\n", + " ax=axs[0], data=df, y=\"spectral_cosine\", hue=\"dataset\", palette=bluepink[:3]\n", + ")\n", + "sns.boxplot(\n", + " ax=axs[1], data=df, y=\"spectral_sqrt_cosine\", hue=\"dataset\", palette=bluepink[:3]\n", + ")\n", + "sns.boxplot(\n", + " ax=axs[2], data=df, y=\"spectral_refl_cosine\", hue=\"dataset\", palette=bluepink[:3]\n", + ")\n", "sns.boxplot(ax=axs[3], data=df, y=\"steins_cosine\", hue=\"dataset\", palette=bluepink[:3])\n", "axs[0].set_title(\"Spectral cosine similarity\")\n", "axs[1].set_title(\"Spectral cosine similarity of sqrt intensities\")\n", @@ -2296,8 +2515,12 @@ "fig, axs = plt.subplots(1, 4, figsize=(18, 6), sharey=False)\n", "\n", "sns.boxplot(ax=axs[0], data=df, y=\"spectral_bias\", hue=\"dataset\", 
palette=bluepink[:3])\n", - "sns.boxplot(ax=axs[1], data=df, y=\"spectral_sqrt_bias\", hue=\"dataset\", palette=bluepink[:3])\n", - "sns.boxplot(ax=axs[2], data=df, y=\"spectral_refl_bias\", hue=\"dataset\", palette=bluepink[:3])\n", + "sns.boxplot(\n", + " ax=axs[1], data=df, y=\"spectral_sqrt_bias\", hue=\"dataset\", palette=bluepink[:3]\n", + ")\n", + "sns.boxplot(\n", + " ax=axs[2], data=df, y=\"spectral_refl_bias\", hue=\"dataset\", palette=bluepink[:3]\n", + ")\n", "sns.boxplot(ax=axs[3], data=df, y=\"steins_bias\", hue=\"dataset\", palette=bluepink[:3])\n", "axs[0].set_title(\"Spectral cosine bias\")\n", "axs[1].set_title(\"Spectral cosine bias of sqrt intensities\")\n", @@ -2334,7 +2557,13 @@ "source": [ "fig, ax = plt.subplots(1, 1, figsize=(12, 6), sharey=False)\n", "\n", - "sns.boxplot(ax=ax, data=df_val, y=\"spectral_cosine\", hue=\"Instrument_type\", palette=bluepink_grad8)\n", + "sns.boxplot(\n", + " ax=ax,\n", + " data=df_val,\n", + " y=\"spectral_cosine\",\n", + " hue=\"Instrument_type\",\n", + " palette=bluepink_grad8,\n", + ")\n", "axs[0].set_title(\"Cosine\")\n", "plt.show()" ] @@ -2384,25 +2613,37 @@ "\n", "\n", "if filter_spectra:\n", - " setup_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit \n", - " setup_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n", - "df_cas[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df_cas[\"CE\"] = 20.0 # actually stepped 20/35/50\n", - "df_cas[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", - "\n", - "metadata_key_map = {\"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"PRECURSOR_MZ\",\n", - " 'precursor_mode': \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\"\n", - " }\n", - "\n", - "df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", - "df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder), axis=1)\n", + " setup_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", + " setup_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", + "df_cas[\"Metabolite\"].apply(\n", + " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", + ")\n", + "df_cas[\"CE\"] = 20.0 # actually stepped 20/35/50\n", + "df_cas[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", + "\n", + "metadata_key_map = {\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"precursor_mz\": \"PRECURSOR_MZ\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"RETENTIONTIME\",\n", + "}\n", + "\n", + "df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n", + ")\n", + "df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder), axis=1\n", + ")\n", "\n", "# Fragmentation\n", "df_cas[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "df_cas.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM), axis=1) # Optional: use mz_cut instead\n" + "df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " ),\n", + " axis=1,\n", + ") # Optional: use mz_cut instead" ] }, { @@ 
-2419,23 +2660,35 @@ "df_cast[\"Metabolite\"] = df_cast[\"SMILES\"].apply(Metabolite)\n", "df_cast[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cast[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df_cast[\"CE\"] = 20.0 # actually stepped 20/35/50\n", - "df_cast[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", - "\n", - "metadata_key_map16 = {\"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"PRECURSOR_MZ\",\n", - " 'precursor_mode': \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\"\n", - " }\n", + "df_cast[\"Metabolite\"].apply(\n", + " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", + ")\n", + "df_cast[\"CE\"] = 20.0 # actually stepped 20/35/50\n", + "df_cast[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", + "\n", + "metadata_key_map16 = {\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"precursor_mz\": \"PRECURSOR_MZ\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"RETENTIONTIME\",\n", + "}\n", "\n", - "df_cast[\"summary\"] = df_cast.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n", - "df_cast.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder), axis=1)\n", + "df_cast[\"summary\"] = df_cast.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", + ")\n", + "df_cast.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder), axis=1\n", + ")\n", "\n", "# Fragmentation\n", "df_cast[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "df_cast.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM), axis=1)\n" + "df_cast.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -2459,14 +2712,14 @@ "# xxx = []\n", "# for i,d in df_cas.iterrows():\n", "# m = d[\"Metabolite\"]\n", - " \n", + "\n", "# for x,D in df.iterrows():\n", "# M = D[\"Metabolite\"]\n", "# if (m == M):\n", "# iii += [i]\n", "# xxx += [x]\n", - " \n", - "# iii = np.unique(iii) \n", + "\n", + "# iii = np.unique(iii)\n", "# print(iii)\n", "# print(f\"Found {len(iii)} instances violating test/train split. 
Metabolite found in train/val set.\")" ] @@ -2488,12 +2741,16 @@ } ], "source": [ - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + ")\n", "\n", - "img = df_cas.loc[0][\"Metabolite\"].draw(ax= axs[0])\n", + "img = df_cas.loc[0][\"Metabolite\"].draw(ax=axs[0])\n", "\n", "axs[0].grid(False)\n", - "axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False)\n", + "axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + ")\n", "axs[0].set_title(str(df_cas.loc[0][\"Metabolite\"]))\n", "axs[0].imshow(img)\n", "axs[0].axis(\"off\")\n", @@ -2584,40 +2841,87 @@ ], "source": [ "from fiora.MOL.collision_energy import NCE_to_eV\n", - "from fiora.MS.spectral_scores import spectral_cosine, spectral_reflection_cosine, reweighted_dot\n", + "from fiora.MS.spectral_scores import (\n", + " spectral_cosine,\n", + " spectral_reflection_cosine,\n", + " reweighted_dot,\n", + ")\n", "from fiora.MS.ms_utility import merge_annotated_spectrum\n", "\n", + "\n", "def test_cas(df_cas):\n", - " df_cas[\"NCE\"] = 20.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1)\n", + " df_cas[\"NCE\"] = 20.0 # actually stepped NCE 20/35/50\n", + " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " )\n", " df_cas[\"step1_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", - " df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + " df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n", + " )\n", + " df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + " )\n", " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_20\")\n", "\n", - " df_cas[\"NCE\"] = 35.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1)\n", + " df_cas[\"NCE\"] = 35.0 # actually stepped NCE 20/35/50\n", + " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " )\n", " df_cas[\"step2_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", - " df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + " df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n", + " )\n", + " df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + " )\n", " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_35\")\n", "\n", - "\n", - " df_cas[\"NCE\"] = 50.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1)\n", + " df_cas[\"NCE\"] = 50.0 # actually stepped NCE 
20/35/50\n", + " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " )\n", " df_cas[\"step3_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", - " df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + " df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n", + " )\n", + " df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + " )\n", " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_50\")\n", "\n", - " df_cas[\"avg_CE\"] = (df_cas[\"step1_CE\"] + df_cas[\"step2_CE\"] + df_cas[\"step3_CE\"]) / 3\n", - "\n", - " df_cas[\"merged_peaks\"] = df_cas.apply(lambda x: merge_annotated_spectrum(merge_annotated_spectrum(x[\"sim_peaks_20\"], x[\"sim_peaks_35\"]), x[\"sim_peaks_50\"]) , axis=1)\n", - " df_cas[\"merged_cosine\"] = df_cas.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"]), axis=1)\n", - " df_cas[\"merged_sqrt_cosine\"] = df_cas.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt), axis=1)\n", - " df_cas[\"merged_refl_cosine\"] = df_cas.apply(lambda x: spectral_reflection_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt), axis=1)\n", - " df_cas[\"merged_steins\"] = df_cas.apply(lambda x: reweighted_dot(x[\"peaks\"], x[\"merged_peaks\"]), axis=1)\n", - " df_cas[\"spectral_sqrt_cosine\"] = df_cas[\"merged_sqrt_cosine\"] # just remember it is merged\n", + " df_cas[\"avg_CE\"] = (\n", + " df_cas[\"step1_CE\"] + df_cas[\"step2_CE\"] + df_cas[\"step3_CE\"]\n", + " ) / 3\n", + "\n", + " df_cas[\"merged_peaks\"] = df_cas.apply(\n", + " lambda x: merge_annotated_spectrum(\n", + " merge_annotated_spectrum(x[\"sim_peaks_20\"], x[\"sim_peaks_35\"]),\n", + " x[\"sim_peaks_50\"],\n", + " ),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_cosine\"] = df_cas.apply(\n", + " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " )\n", + " df_cas[\"merged_sqrt_cosine\"] = df_cas.apply(\n", + " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_refl_cosine\"] = df_cas.apply(\n", + " lambda x: spectral_reflection_cosine(\n", + " x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt\n", + " ),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_steins\"] = df_cas.apply(\n", + " lambda x: reweighted_dot(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " )\n", + " df_cas[\"spectral_sqrt_cosine\"] = df_cas[\n", + " \"merged_sqrt_cosine\"\n", + " ] # just remember it is merged\n", "\n", " df_cas[\"coverage\"] = df_cas[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", " df_cas[\"RT_pred\"] = df_cas[\"RT_pred_35\"]\n", @@ -2627,6 +2931,7 @@ "\n", " return df_cas\n", "\n", + "\n", "df_cas = test_cas(df_cas)\n", "df_cast = test_cas(df_cast)" ] @@ -2668,13 +2973,25 @@ } ], "source": [ - "fig, ax = plt.subplots(1,1, figsize=(8,4))\n", - "sns.histplot(df_cas, x=\"avg_CE\", hue=\"Precursor_type\", multiple=\"stack\", palette=[\"black\", \"gray\"]) #bluepink[:2][::-1])\n", + "fig, ax = plt.subplots(1, 1, figsize=(8, 4))\n", + "sns.histplot(\n", + " df_cas,\n", + " x=\"avg_CE\",\n", + " hue=\"Precursor_type\",\n", + " multiple=\"stack\",\n", + " palette=[\"black\", 
\"gray\"],\n", + ") # bluepink[:2][::-1])\n", "plt.xlabel(\"Average collision energy\")\n", "plt.show()\n", "\n", - "fig, ax = plt.subplots(1,1, figsize=(8,4))\n", - "sns.histplot(df_cast, x=\"avg_CE\", hue=\"Precursor_type\", multiple=\"stack\", palette=[\"black\", \"gray\"]) #bluepink[:2][::-1])\n", + "fig, ax = plt.subplots(1, 1, figsize=(8, 4))\n", + "sns.histplot(\n", + " df_cast,\n", + " x=\"avg_CE\",\n", + " hue=\"Precursor_type\",\n", + " multiple=\"stack\",\n", + " palette=[\"black\", \"gray\"],\n", + ") # bluepink[:2][::-1])\n", "plt.xlabel(\"Average collision energy\")\n", "plt.show()" ] @@ -2740,17 +3057,25 @@ "i = 3\n", "print(df_cas.loc[i][\"merged_sqrt_cosine\"])\n", "\n", - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", - "img = df_cas.loc[i][\"Metabolite\"].draw(ax= axs[0])\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + ")\n", + "img = df_cas.loc[i][\"Metabolite\"].draw(ax=axs[0])\n", "\n", "axs[0].grid(False)\n", - "axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False) \n", + "axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + ")\n", "axs[0].set_title(df_cas.loc[i][\"NAME\"])\n", "axs[0].imshow(img)\n", "axs[0].axis(\"off\")\n", - "#sv.plot_spectrum(example, ax=axs[1])\n", - "ax = sv.plot_spectrum({\"peaks\": df_cas.loc[i][\"peaks\"]}, {\"peaks\": df_cas.loc[i][\"sim_peaks_35\"]}, ax=axs[1])\n", - "#axs[1].text(0.5, 0.5, 'matplotlib', horizontalalignment='center', verticalalignment='center', transform=axs[1].transAxes)\n", + "# sv.plot_spectrum(example, ax=axs[1])\n", + "ax = sv.plot_spectrum(\n", + " {\"peaks\": df_cas.loc[i][\"peaks\"]},\n", + " {\"peaks\": df_cas.loc[i][\"sim_peaks_35\"]},\n", + " ax=axs[1],\n", + ")\n", + "# axs[1].text(0.5, 0.5, 'matplotlib', horizontalalignment='center', verticalalignment='center', transform=axs[1].transAxes)\n", "plt.show()" ] }, @@ -2785,7 +3110,13 @@ "fig, ax = plt.subplots(1, 1, figsize=(8, 4), sharey=False)\n", "\n", "\n", - "sns.boxplot(ax=ax, data=df_cas, y=\"merged_sqrt_cosine\", x=\"Precursor_type\", palette=bluepink[:2][::-1])\n", + "sns.boxplot(\n", + " ax=ax,\n", + " data=df_cas,\n", + " y=\"merged_sqrt_cosine\",\n", + " x=\"Precursor_type\",\n", + " palette=bluepink[:2][::-1],\n", + ")\n", "ax.set_title(\"Spectral cosine similarity of sqrt intensities\")\n", "plt.show()" ] @@ -2809,7 +3140,8 @@ " abs_error = abs(ref - CE)\n", " i = np.argmin(abs_error)\n", " return str(ref[i])\n", - " \n", + "\n", + "\n", "df_cas[\"cfm_CE\"] = df_cas[\"avg_CE\"].apply(closest_cfm_ce)" ] }, @@ -2830,8 +3162,12 @@ "import fiora.IO.cfmReader as cfmReader\n", "# time CFM-ID 4: -> 12m16,571s\n", "\n", - "cf = cfmReader.read(f\"{home}/data/metabolites/cfm-id/casmi16_negative_predictions.txt\", as_df=True)\n", - "cf_p = cfmReader.read(f\"{home}/data/metabolites/cfm-id/casmi16_positive_predictions.txt\", as_df=True)\n", + "cf = cfmReader.read(\n", + " f\"{home}/data/metabolites/cfm-id/casmi16_negative_predictions.txt\", as_df=True\n", + ")\n", + "cf_p = cfmReader.read(\n", + " f\"{home}/data/metabolites/cfm-id/casmi16_positive_predictions.txt\", as_df=True\n", + ")\n", "cf = pd.concat([cf, cf_p])" ] }, @@ -2852,7 +3188,7 @@ } ], "source": [ - "len(cf[cf[\"#ID\"] == \"Challenge-009\"]) ## missing chalenges" + "len(cf[cf[\"#ID\"] == \"Challenge-009\"]) ## missing chalenges" ] }, { @@ -2889,15 +3225,18 @@ " 
print(f\"{challenge} not found in CFM-ID results. Skipping.\")\n", " continue\n", " cfm_data = cf[cf[\"#ID\"] == challenge].iloc[0]\n", - " \n", - " \n", + "\n", " if cas[\"ChallengeName\"] != cfm_data[\"#ID\"]:\n", " raise ValueError(\"Wrong challenge matched\")\n", - " cfm_peaks = cfm_data[\"peaks\" + cas[\"cfm_CE\"]] # find best reference CE\n", + " cfm_peaks = cfm_data[\"peaks\" + cas[\"cfm_CE\"]] # find best reference CE\n", " df_cas.at[i, \"cfm_peaks\"] = cfm_peaks\n", " df_cas.at[i, \"cfm_cosine\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks)\n", - " df_cas.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks, transform=np.sqrt)\n", - " df_cas.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(cas[\"peaks\"], cfm_peaks, transform=np.sqrt)\n", + " df_cas.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(\n", + " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " )\n", + " df_cas.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(\n", + " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " )\n", " df_cas.at[i, \"cfm_steins\"] = reweighted_dot(cas[\"peaks\"], cfm_peaks)" ] }, @@ -2918,14 +3257,39 @@ } ], "source": [ - "cosines = {\"cosine\": df_cas[\"merged_cosine\"], \"sqrt_cosine\": df_cas[\"merged_sqrt_cosine\"], \"refl_cosine\": df_cas[\"merged_refl_cosine\"], \"steins_dot\": df_cas[\"merged_steins\"], \"model\": \"Fiora\"}\n", - "cosines2 = {\"cosine\": df_cas[\"cfm_cosine\"], \"sqrt_cosine\": df_cas[\"cfm_sqrt_cosine\"],\"refl_cosine\": df_cas[\"cfm_refl_cosine\"], \"steins_dot\": df_cas[\"cfm_steins\"] , \"model\": \"CFM-ID 4.4.7\"}\n", - "fig, axs = plt.subplots(1, 3, figsize=(18.8, 4.2), gridspec_kw={'width_ratios': [1, 1, 1]}, sharey=False)\n", + "cosines = {\n", + " \"cosine\": df_cas[\"merged_cosine\"],\n", + " \"sqrt_cosine\": df_cas[\"merged_sqrt_cosine\"],\n", + " \"refl_cosine\": df_cas[\"merged_refl_cosine\"],\n", + " \"steins_dot\": df_cas[\"merged_steins\"],\n", + " \"model\": \"Fiora\",\n", + "}\n", + "cosines2 = {\n", + " \"cosine\": df_cas[\"cfm_cosine\"],\n", + " \"sqrt_cosine\": df_cas[\"cfm_sqrt_cosine\"],\n", + " \"refl_cosine\": df_cas[\"cfm_refl_cosine\"],\n", + " \"steins_dot\": df_cas[\"cfm_steins\"],\n", + " \"model\": \"CFM-ID 4.4.7\",\n", + "}\n", + "fig, axs = plt.subplots(\n", + " 1, 3, figsize=(18.8, 4.2), gridspec_kw={\"width_ratios\": [1, 1, 1]}, sharey=False\n", + ")\n", "\n", "scores = pd.concat([pd.DataFrame(cosines), pd.DataFrame(cosines2)], ignore_index=True)\n", - "sns.histplot(scores, ax=axs[0], x=\"cosine\", hue=\"model\", palette=bluepink[:2], binwidth=0.025)\n", - "sns.histplot(scores, ax=axs[1], x=\"sqrt_cosine\", hue=\"model\", palette=bluepink[:2],binwidth=0.025)\n", - "sns.histplot(scores, ax=axs[2], x=\"steins_dot\", hue=\"model\", palette=bluepink[:2],binwidth=0.025)\n", + "sns.histplot(\n", + " scores, ax=axs[0], x=\"cosine\", hue=\"model\", palette=bluepink[:2], binwidth=0.025\n", + ")\n", + "sns.histplot(\n", + " scores,\n", + " ax=axs[1],\n", + " x=\"sqrt_cosine\",\n", + " hue=\"model\",\n", + " palette=bluepink[:2],\n", + " binwidth=0.025,\n", + ")\n", + "sns.histplot(\n", + " scores, ax=axs[2], x=\"steins_dot\", hue=\"model\", palette=bluepink[:2], binwidth=0.025\n", + ")\n", "plt.show()" ] }, @@ -2937,6 +3301,7 @@ "source": [ "from matplotlib.patches import PathPatch\n", "\n", + "\n", "def adjust_box_widthsOld(fig, fac):\n", " \"\"\"\n", " Adjust the withs of a seaborn-generated boxplot.\n", @@ -2944,10 +3309,8 @@ "\n", " # iterating through Axes instances\n", " for ax in fig.axes:\n", - "\n", " # 
iterating through axes artists:\n", " for c in ax.get_children():\n", - "\n", " # searching for PathPatches\n", " if isinstance(c, PathPatch):\n", " # getting current width of box:\n", @@ -2956,12 +3319,12 @@ " verts_sub = verts[:-1]\n", " xmin = np.min(verts_sub[:, 0])\n", " xmax = np.max(verts_sub[:, 0])\n", - " xmid = 0.5*(xmin+xmax)\n", - " xhalf = 0.5*(xmax - xmin)\n", + " xmid = 0.5 * (xmin + xmax)\n", + " xhalf = 0.5 * (xmax - xmin)\n", "\n", " # setting new width of box\n", - " xmin_new = xmid-fac*xhalf\n", - " xmax_new = xmid+fac*xhalf\n", + " xmin_new = xmid - fac * xhalf\n", + " xmax_new = xmid + fac * xhalf\n", " verts_sub[verts_sub[:, 0] == xmin, 0] = xmin_new\n", " verts_sub[verts_sub[:, 0] == xmax, 0] = xmax_new\n", "\n", @@ -2969,12 +3332,11 @@ " for l in ax.lines:\n", " if np.all(l.get_xdata() == [xmin, xmax]):\n", " l.set_xdata([xmin_new, xmax_new])\n", - " \n", - " \n", "\n", "\n", "from matplotlib.patches import PathPatch\n", "\n", + "\n", "def adjust_box_widths(fig, fac):\n", " \"\"\"\n", " Adjust the widths of a seaborn-generated boxplot.\n", @@ -2982,10 +3344,8 @@ "\n", " # iterating through Axes instances\n", " for ax in fig.axes:\n", - "\n", " # iterating through axes artists\n", " for c in ax.get_children():\n", - "\n", " # searching for PathPatches\n", " if isinstance(c, PathPatch):\n", " # getting current width of box\n", @@ -3008,8 +3368,11 @@ " # check if the line has data\n", " if len(l.get_xdata()) > 0:\n", " # check if the line is a median line\n", - " if 'color' in l.properties() and l.properties()['color'] == 'black':\n", - " l.set_xdata([xmin_new, xmax_new])\n" + " if (\n", + " \"color\" in l.properties()\n", + " and l.properties()[\"color\"] == \"black\"\n", + " ):\n", + " l.set_xdata([xmin_new, xmax_new])" ] }, { @@ -3076,9 +3439,9 @@ } ], "source": [ - "print(\"Mean:\\t\", round(df_cas[\"merged_sqrt_cosine\"].mean(), 2)) #0.634\n", - "print(\"Median:\\t\", round(df_cas[\"merged_sqrt_cosine\"].median(), 2)) #0.737\n", - "print(\"Var:\\t\", round(df_cas[\"merged_sqrt_cosine\"].var(), 2)) #0.116" + "print(\"Mean:\\t\", round(df_cas[\"merged_sqrt_cosine\"].mean(), 2)) # 0.634\n", + "print(\"Median:\\t\", round(df_cas[\"merged_sqrt_cosine\"].median(), 2)) # 0.737\n", + "print(\"Var:\\t\", round(df_cas[\"merged_sqrt_cosine\"].var(), 2)) # 0.116" ] }, { @@ -3117,7 +3480,9 @@ ], "source": [ "df_cas[\"higher_cosine\"] = (df_cas[\"merged_sqrt_cosine\"] - df_cas[\"cfm_sqrt_cosine\"]) > 0\n", - "df_cas[\"smaller_cosine\"] = (df_cas[\"merged_sqrt_cosine\"] - df_cas[\"cfm_sqrt_cosine\"]) < 0\n", + "df_cas[\"smaller_cosine\"] = (\n", + " df_cas[\"merged_sqrt_cosine\"] - df_cas[\"cfm_sqrt_cosine\"]\n", + ") < 0\n", "h, l = sum(df_cas[\"higher_cosine\"]), sum(df_cas[\"smaller_cosine\"])\n", "print(f\"Higher in {h} of cases (smaller in {l} cases) out of {df_cas.shape[0]}\")" ] @@ -3170,35 +3535,60 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "import matplotlib.patches as mpatches\n", "\n", + "\n", "def double_mirrorplot(i, model_title=\"Fiora\"):\n", - " fig, axs = plt.subplots(1, 3, figsize=(16.8, 4.2), gridspec_kw={'width_ratios': [1, 3, 3]}, sharey=False)\n", - " \n", + " fig, axs = plt.subplots(\n", + " 1, 3, figsize=(16.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3, 3]}, sharey=False\n", + " )\n", + "\n", " plt.subplots_adjust(right=0.975, left=0.025)\n", - " \n", - " img = df_cas.loc[i][\"Metabolite\"].draw(ax= axs[0])\n", + "\n", + " img = df_cas.loc[i][\"Metabolite\"].draw(ax=axs[0])\n", "\n", " axs[0].grid(False)\n", - " 
axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False) \n", - " axs[0].set_title(df_cas.loc[i][\"NAME\"] + \"\\n(\" + df_cas.loc[i][\"ChallengeName\"]+ \")\")\n", + " axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " )\n", + " axs[0].set_title(\n", + " df_cas.loc[i][\"NAME\"] + \"\\n(\" + df_cas.loc[i][\"ChallengeName\"] + \")\"\n", + " )\n", " axs[0].imshow(img)\n", " axs[0].axis(\"off\")\n", "\n", - " sv.plot_spectrum({\"peaks\": df_cas.loc[i][\"peaks\"]}, {\"peaks\": df_cas.loc[i][\"merged_peaks\"]}, ax=axs[1])\n", + " sv.plot_spectrum(\n", + " {\"peaks\": df_cas.loc[i][\"peaks\"]},\n", + " {\"peaks\": df_cas.loc[i][\"merged_peaks\"]},\n", + " ax=axs[1],\n", + " )\n", " axs[1].title.set_text(model_title)\n", - " patch1 = mpatches.Patch(color='limegreen' if df_cas.loc[i][\"cfm_sqrt_cosine\"] < df_cas.loc[i][\"merged_sqrt_cosine\"] else \"orangered\", label=f'cosine {df_cas.loc[i][\"merged_sqrt_cosine\"]:.02f}')\n", + " patch1 = mpatches.Patch(\n", + " color=\"limegreen\"\n", + " if df_cas.loc[i][\"cfm_sqrt_cosine\"] < df_cas.loc[i][\"merged_sqrt_cosine\"]\n", + " else \"orangered\",\n", + " label=f\"cosine {df_cas.loc[i]['merged_sqrt_cosine']:.02f}\",\n", + " )\n", " axs[1].legend(handles=[patch1])\n", "\n", - " sv.plot_spectrum({\"peaks\": df_cas.loc[i][\"peaks\"]}, {\"peaks\": df_cas.loc[i][\"cfm_peaks\"]} if df_cas.loc[i][\"cfm_peaks\"] else {\"peaks\": {\"mz\": [0], \"intensity\": [0]}}, ax=axs[2])\n", - " axs[2].title.set_text(f'CFM-ID 4.4.7')\n", - " \n", - " patch2 = mpatches.Patch(color='limegreen' if df_cas.loc[i][\"cfm_sqrt_cosine\"] > df_cas.loc[i][\"merged_sqrt_cosine\"] else \"orangered\", label=f'cosine {df_cas.loc[i][\"cfm_sqrt_cosine\"]:.02f}', )\n", + " sv.plot_spectrum(\n", + " {\"peaks\": df_cas.loc[i][\"peaks\"]},\n", + " {\"peaks\": df_cas.loc[i][\"cfm_peaks\"]}\n", + " if df_cas.loc[i][\"cfm_peaks\"]\n", + " else {\"peaks\": {\"mz\": [0], \"intensity\": [0]}},\n", + " ax=axs[2],\n", + " )\n", + " axs[2].title.set_text(f\"CFM-ID 4.4.7\")\n", + "\n", + " patch2 = mpatches.Patch(\n", + " color=\"limegreen\"\n", + " if df_cas.loc[i][\"cfm_sqrt_cosine\"] > df_cas.loc[i][\"merged_sqrt_cosine\"]\n", + " else \"orangered\",\n", + " label=f\"cosine {df_cas.loc[i]['cfm_sqrt_cosine']:.02f}\",\n", + " )\n", " axs[2].legend(handles=[patch2])\n", - " \n", - " return fig, axs\n", - "\n" + "\n", + " return fig, axs" ] }, { @@ -3211,8 +3601,8 @@ "if plot_examples:\n", " for i in range(df_cas.shape[0]):\n", " print(i)\n", - " fig, axs = double_mirrorplot(i) \n", - " plt.show()\n" + " fig, axs = double_mirrorplot(i)\n", + " plt.show()" ] }, { @@ -3266,7 +3656,10 @@ "source": [ "thecaseforfioracase_p = [6, 10, 24, 25, 33, 38, 49, 51, 52, 58, 88, 110]\n", "thecaseforfioracase = [x + 81 for x in thecaseforfioracase_p] + [3, 17]\n", - "weirdcases = [46, 88] # 46 symmetric -> issues withh pea display # 88 -> actually CFM-ID barely matches any of the high peaks (on high res)" + "weirdcases = [\n", + " 46,\n", + " 88,\n", + "] # 46 symmetric -> issues with peak display # 88 -> actually CFM-ID barely matches any of the high peaks (on high res)" ] }, { @@ -3279,7 +3672,10 @@ "if saveFig:\n", " for i in thecaseforfioracase:\n", " fig, axs = double_mirrorplot(i)\n", - " plt.savefig(f\"{home}/images/\" + df_cas.at[i, \"ChallengeName\"] + \"_mirror.svg\", format=\"svg\")\n", + " plt.savefig(\n", + " f\"{home}/images/\" + df_cas.at[i, \"ChallengeName\"] + \"_mirror.svg\",\n", + " format=\"svg\",\n", 
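+ " # one double_mirrorplot figure (observed vs. Fiora, observed vs. CFM-ID) is saved per selected challenge\n",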
+ " )\n", " plt.clf()" ] }, @@ -3289,7 +3685,7 @@ "metadata": {}, "outputs": [], "source": [ - "# i = 33 + 81 \n", + "# i = 33 + 81\n", "# fig, axs = double_mirrorplot(i, model_title=\"GNN-based fragmentation\")\n", "# plt.savefig(f\"{home}/images/\" + df_cas.at[i, \"ChallengeName\"] + \"_gnn_mirror.svg\", format=\"svg\")" ] @@ -3301,7 +3697,7 @@ "outputs": [], "source": [ "# for i, mz in enumerate(df_cas.loc[88][\"peaks\"][\"mz\"]):\n", - "# if df_cas.loc[88][\"peaks\"][\"intensity\"][i] > 5000000: \n", + "# if df_cas.loc[88][\"peaks\"][\"intensity\"][i] > 5000000:\n", "# print(mz, df_cas.loc[88][\"peaks\"][\"intensity\"][i] / 1000000.0)\n", "# print(df_cas.loc[88][\"sim_peaks_35\"])\n", "# for i, mz in enumerate(df_cas.loc[88][\"sim_peaks_35\"][\"mz\"]):\n", @@ -3309,7 +3705,7 @@ "# print(df_cas.loc[88][\"cfm_peaks\"] )\n", "# for i, mz in enumerate(df_cas.loc[88][\"cfm_peaks\"][\"mz\"]):\n", "# if df_cas.loc[88][\"cfm_peaks\"][\"intensity\"][i] > 15.6: print(mz, df_cas.loc[88][\"cfm_peaks\"][\"intensity\"][i])\n", - " \n", + "\n", "\n", "# double_mirrorplot(88)\n", "# plt.show()\n", @@ -3342,14 +3738,14 @@ "# #num_categories = 10\n", "# #for lab in range(num_categories):\n", "# #indices = test_predictions==lab\n", - "# ax.scatter(tsne_proj[:,0],tsne_proj[:,1] ,alpha=0.5) #label = lab \n", + "# ax.scatter(tsne_proj[:,0],tsne_proj[:,1] ,alpha=0.5) #label = lab\n", "# #pep_indices = np.where(df_val[\"is_peptide\"])[0]\n", "# df_val[\"precursor_prob\"] = df[\"Metabolite\"].apply(lambda x: x.precursor_prob)\n", "# pep_indices = np.where(df_val[\"precursor_prob\"] > 0.9)[0]\n", "# #pep_indices = np.where(df_val[\"ring_proportion\"] > 0.8)[0]\n", "# #pep_indices = np.where(df_val[\"Precursor_type\"] == \"[M-H]-\")[0]\n", "# #pep_indices = np.where(np.logical_and(df_val[\"anyN\"], df_val[\"anyO\"]))[0]\n", - "# ax.scatter(tsne_proj[pep_indices,0],tsne_proj[pep_indices,1], c=\"red\" ,alpha=0.5) #label = lab \n", + "# ax.scatter(tsne_proj[pep_indices,0],tsne_proj[pep_indices,1], c=\"red\" ,alpha=0.5) #label = lab\n", "\n", "\n", "# #ax.scatter(tsne_proj[indices,0],tsne_proj[indices,1], c=np.array(cmap(lab)).reshape(1,4), label = lab ,alpha=0.5)\n", @@ -3535,9 +3931,13 @@ "# df_cas22[col] = df_cas22[col].apply(ast.literal_eval)\n", "\n", "print(df_cas22.shape)\n", - "df_cas22[\"ChallengeNum\"] = df_cas22[\"ChallengeName\"].apply(lambda x: int(x.split(\"-\")[-1]))\n", - "try: df_cas22.reset_index(inplace=True)\n", - "except: pass\n", + "df_cas22[\"ChallengeNum\"] = df_cas22[\"ChallengeName\"].apply(\n", + " lambda x: int(x.split(\"-\")[-1])\n", + ")\n", + "try:\n", + " df_cas22.reset_index(inplace=True)\n", + "except:\n", + " pass\n", "df_cas22.head(2)" ] }, @@ -3551,22 +3951,37 @@ "df_cas22[\"Metabolite\"] = df_cas22[\"SMILES\"].apply(Metabolite)\n", "df_cas22[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cas22[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df_cas22[\"CE\"] = df_cas22.apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"precursor_mz\"]), axis=1)\n", + "df_cas22[\"Metabolite\"].apply(\n", + " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", + ")\n", + "df_cas22[\"CE\"] = df_cas22.apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"precursor_mz\"]), axis=1\n", + ")\n", "\n", - "metadata_key_map = {\"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"precursor_mz\",\n", - " 'precursor_mode': \"Precursor_type\",\n", - " \"retention_time\": 
\"ChallengeRT\"\n", - " }\n", + "metadata_key_map = {\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"precursor_mz\": \"precursor_mz\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"ChallengeRT\",\n", + "}\n", "\n", - "df_cas22[\"summary\"] = df_cas22.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", - "df_cas22.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + "df_cas22[\"summary\"] = df_cas22.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n", + ")\n", + "df_cas22.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + ")\n", "\n", "# Fragmentation\n", "df_cas22[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "df_cas22.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM), axis=1) # Optional: use mz_cut instead\n" + "df_cas22.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " ),\n", + " axis=1,\n", + ") # Optional: use mz_cut instead" ] }, { @@ -3588,8 +4003,8 @@ } ], "source": [ - "df_cas22_unique = df_cas22.drop_duplicates(subset='ChallengeName', keep='first')\n", - "#df_cas22_unique.reset_index(inplace=True)\n", + "df_cas22_unique = df_cas22.drop_duplicates(subset=\"ChallengeName\", keep=\"first\")\n", + "# df_cas22_unique.reset_index(inplace=True)\n", "df_cas22_unique[\"Metabolite\"] = df_cas22_unique[\"SMILES\"].apply(Metabolite)" ] }, @@ -3615,17 +4030,16 @@ "# xxx = []\n", "\n", "\n", - "\n", "# for i,d in df_cas22_unique.iterrows():\n", "# m = d[\"Metabolite\"]\n", - " \n", + "\n", "# for x,D in df.iterrows():\n", "# M = D[\"Metabolite\"]\n", "# if (m == M):\n", "# iii += [i]\n", "# xxx += [x]\n", - " \n", - "# iii = np.unique(iii) \n", + "\n", + "# iii = np.unique(iii)\n", "# print(iii)\n", "# print(f\"Found {len(iii)} instances violating test/train split. 
Metabolite found in train/val set.\")" ] @@ -3647,12 +4061,16 @@ } ], "source": [ - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + ")\n", "\n", - "img = df_cas22.loc[0][\"Metabolite\"].draw(ax= axs[0])\n", + "img = df_cas22.loc[0][\"Metabolite\"].draw(ax=axs[0])\n", "\n", "axs[0].grid(False)\n", - "axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False)\n", + "axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + ")\n", "axs[0].set_title(str(df_cas22.loc[0][\"Metabolite\"]))\n", "axs[0].imshow(img)\n", "axs[0].axis(\"off\")\n", @@ -3714,11 +4132,25 @@ "source": [ "fig, axs = plt.subplots(1, 2, figsize=(12, 6), sharey=True)\n", "\n", - "sns.histplot(ax=axs[0], data=df_cas22, x=\"CE\", hue=\"Precursor_type\", multiple=\"stack\", palette=[\"black\", \"gray\"])#bluepink[:2][::-1])\n", + "sns.histplot(\n", + " ax=axs[0],\n", + " data=df_cas22,\n", + " x=\"CE\",\n", + " hue=\"Precursor_type\",\n", + " multiple=\"stack\",\n", + " palette=[\"black\", \"gray\"],\n", + ") # bluepink[:2][::-1])\n", "axs[0].vlines(CE_upper_limit, 0, 120, color=\"red\")\n", "axs[0].set_xlabel(\"Collision energy\")\n", "\n", - "sns.histplot(ax=axs[1], data=df_cas22, x=\"precursor_mz\", hue=\"Precursor_type\", multiple=\"stack\", palette=[\"black\", \"gray\"])#bluepink[:2][::-1])\n", + "sns.histplot(\n", + " ax=axs[1],\n", + " data=df_cas22,\n", + " x=\"precursor_mz\",\n", + " hue=\"Precursor_type\",\n", + " multiple=\"stack\",\n", + " palette=[\"black\", \"gray\"],\n", + ") # bluepink[:2][::-1])\n", "# axs[1].vlines(800, 0, 120, color=\"red\")\n", "axs[1].set_xlabel(\"Precursor mz\")\n", "plt.show()" @@ -3741,7 +4173,11 @@ "outputs": [], "source": [ "df_cas22[\"coverage\"] = df_cas22[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", - "df_cas22[\"ring_proportion\"] = df_cas22[\"Metabolite\"].apply(lambda x: (getattr(x, \"is_edge_in_ring\").sum() / getattr(x, \"is_edge_in_ring\").shape[0]).tolist())" + "df_cas22[\"ring_proportion\"] = df_cas22[\"Metabolite\"].apply(\n", + " lambda x: (\n", + " getattr(x, \"is_edge_in_ring\").sum() / getattr(x, \"is_edge_in_ring\").shape[0]\n", + " ).tolist()\n", + ")" ] }, { @@ -3750,9 +4186,24 @@ "metadata": {}, "outputs": [], "source": [ - "sns.boxplot(ax=axs[0], data=df_cas22, y=\"spectral_sqrt_cosine\", x=\"Precursor_type\", hue=\"NCE\", palette=bluepink_grad8[-3:][::-1])\n", + "sns.boxplot(\n", + " ax=axs[0],\n", + " data=df_cas22,\n", + " y=\"spectral_sqrt_cosine\",\n", + " x=\"Precursor_type\",\n", + " hue=\"NCE\",\n", + " palette=bluepink_grad8[-3:][::-1],\n", + ")\n", "axs[0].set_title(\"Spectral cosine similarity of sqrt intensities\")\n", - "sns.scatterplot(ax=axs[1], data=df_cas22, x=\"coverage\", y=\"spectral_sqrt_cosine\", hue=\"spectral_sqrt_cosine\", hue_norm=(0, 1), palette=bluepink_grad)\n", + "sns.scatterplot(\n", + " ax=axs[1],\n", + " data=df_cas22,\n", + " x=\"coverage\",\n", + " y=\"spectral_sqrt_cosine\",\n", + " hue=\"spectral_sqrt_cosine\",\n", + " hue_norm=(0, 1),\n", + " palette=bluepink_grad,\n", + ")\n", "plt.show()" ] }, @@ -3783,17 +4234,25 @@ "i = 3\n", "print(df_cas22.loc[i][\"spectral_sqrt_cosine\"])\n", "\n", - "fig, axs = plt.subplots(1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", - "img = df_cas22.loc[i][\"Metabolite\"].draw(ax= 
axs[0])\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + ")\n", + "img = df_cas22.loc[i][\"Metabolite\"].draw(ax=axs[0])\n", "\n", "axs[0].grid(False)\n", - "axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False) \n", + "axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + ")\n", "axs[0].set_title(df_cas.loc[i][\"NAME\"])\n", "axs[0].imshow(img)\n", "axs[0].axis(\"off\")\n", - "#sv.plot_spectrum(example, ax=axs[1])\n", - "ax = sv.plot_spectrum({\"peaks\": df_cas22.loc[i][\"peaks\"]}, {\"peaks\": df_cas22.loc[i][\"sim_peaks\"]}, ax=axs[1])\n", - "#axs[1].text(0.5, 0.5, 'matplotlib', horizontalalignment='center', verticalalignment='center', transform=axs[1].transAxes)\n", + "# sv.plot_spectrum(example, ax=axs[1])\n", + "ax = sv.plot_spectrum(\n", + " {\"peaks\": df_cas22.loc[i][\"peaks\"]},\n", + " {\"peaks\": df_cas22.loc[i][\"sim_peaks\"]},\n", + " ax=axs[1],\n", + ")\n", + "# axs[1].text(0.5, 0.5, 'matplotlib', horizontalalignment='center', verticalalignment='center', transform=axs[1].transAxes)\n", "plt.show()" ] }, @@ -3814,15 +4273,34 @@ } ], "source": [ - "fig, axs = plt.subplots(2, 1, figsize=(12, 14), sharex=True, gridspec_kw={'height_ratios': [1, 5]})\n", - "plt.subplots_adjust(hspace=0.05)#right=0.975, left=0.11)\n", - "#sns.histplot(ax=axs[0], data=df_cas22, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "sns.kdeplot(ax=axs[0], data=df_cas22, x=\"coverage\", bw_adjust=0.2, color=\"black\", multiple=\"stack\", hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "axs[0].spines['top'].set_visible(False)\n", - "axs[0].spines['right'].set_visible(False)\n", + "fig, axs = plt.subplots(\n", + " 2, 1, figsize=(12, 14), sharex=True, gridspec_kw={\"height_ratios\": [1, 5]}\n", + ")\n", + "plt.subplots_adjust(hspace=0.05) # right=0.975, left=0.11)\n", + "# sns.histplot(ax=axs[0], data=df_cas22, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "sns.kdeplot(\n", + " ax=axs[0],\n", + " data=df_cas22,\n", + " x=\"coverage\",\n", + " bw_adjust=0.2,\n", + " color=\"black\",\n", + " multiple=\"stack\",\n", + " hue=\"Precursor_type\",\n", + " palette=[\"black\", \"gray\"],\n", + ") # hue=\"Precursor_type\",\n", + "axs[0].spines[\"top\"].set_visible(False)\n", + "axs[0].spines[\"right\"].set_visible(False)\n", "\n", "axs[0].set_title(\"Impact of coverage on cosine scores\")\n", - "sns.scatterplot(ax=axs[1], data=df_cas22, x=\"coverage\", y=\"spectral_sqrt_cosine\", hue=\"spectral_sqrt_cosine\", hue_norm=(0, 1), palette=bluepink_grad)\n", + "sns.scatterplot(\n", + " ax=axs[1],\n", + " data=df_cas22,\n", + " x=\"coverage\",\n", + " y=\"spectral_sqrt_cosine\",\n", + " hue=\"spectral_sqrt_cosine\",\n", + " hue_norm=(0, 1),\n", + " palette=bluepink_grad,\n", + ")\n", "axs[1].set_ylabel(\"Cosine similarity\")\n", "axs[1].set_xlabel(\"Peak intensity coverage\")\n", "plt.show()" @@ -3971,7 +4449,7 @@ "outputs": [], "source": [ "df_cas22[\"library\"] = \"CASMI-22\"\n", - "df_cas22[\"RETENTIONTIME\"] = df_cas22[\"ChallengeRT\"] # \"RT_min\"\n", + "df_cas22[\"RETENTIONTIME\"] = df_cas22[\"ChallengeRT\"] # \"RT_min\"\n", "df_cas22[\"cfm_CE\"] = 
df_cas22[\"CE\"].apply(closest_cfm_ce)" ] }, @@ -3985,9 +4463,12 @@ "# time CFM-ID 4: -> 12m16,571s\n", "\n", "\n", - "\n", - "cf22 = cfmReader.read(f\"{home}/data/metabolites/cfm-id/casmi22_negative_predictions.txt\", as_df=True)\n", - "cf22_p = cfmReader.read(f\"{home}/data/metabolites/cfm-id/casmi22_positive_predictions.txt\", as_df=True)\n", + "cf22 = cfmReader.read(\n", + " f\"{home}/data/metabolites/cfm-id/casmi22_negative_predictions.txt\", as_df=True\n", + ")\n", + "cf22_p = cfmReader.read(\n", + " f\"{home}/data/metabolites/cfm-id/casmi22_positive_predictions.txt\", as_df=True\n", + ")\n", "cf22 = pd.concat([cf22, cf22_p])" ] }, @@ -4015,15 +4496,18 @@ " print(f\"{challenge} not found in CFM-ID results. Skipping.\")\n", " continue\n", " cfm_data = cf22[cf22[\"#ID\"] == challenge].iloc[0]\n", - " \n", - " \n", + "\n", " if cas[\"ChallengeName\"] != cfm_data[\"#ID\"]:\n", " raise ValueError(\"Wrong challenge matched\")\n", - " cfm_peaks = cfm_data[\"peaks\" + cas[\"cfm_CE\"]] # find best reference CE\n", + " cfm_peaks = cfm_data[\"peaks\" + cas[\"cfm_CE\"]] # find best reference CE\n", " df_cas22.at[i, \"cfm_peaks\"] = cfm_peaks\n", " df_cas22.at[i, \"cfm_cosine\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks)\n", - " df_cas22.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks, transform=np.sqrt)\n", - " df_cas22.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(cas[\"peaks\"], cfm_peaks, transform=np.sqrt)\n", + " df_cas22.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(\n", + " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " )\n", + " df_cas22.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(\n", + " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " )\n", " df_cas22.at[i, \"cfm_steins\"] = reweighted_dot(cas[\"peaks\"], cfm_peaks)" ] }, @@ -4053,9 +4537,23 @@ ], "source": [ "fig, axs = plt.subplots(1, 2, figsize=(18, 8.6), sharey=True)\n", - "sns.boxplot(ax=axs[0], data=df_cas22, y=\"spectral_sqrt_cosine\", x=\"Precursor_type\", hue=\"NCE\", palette=bluepink_grad8[-3:][::-1])\n", + "sns.boxplot(\n", + " ax=axs[0],\n", + " data=df_cas22,\n", + " y=\"spectral_sqrt_cosine\",\n", + " x=\"Precursor_type\",\n", + " hue=\"NCE\",\n", + " palette=bluepink_grad8[-3:][::-1],\n", + ")\n", "axs[0].set_title(\"Spectral cosine similarity of sqrt intensities\")\n", - "sns.boxplot(ax=axs[1], data=df_cas22, y=\"cfm_sqrt_cosine\", x=\"Precursor_type\", hue=\"NCE\", palette=bluepink_grad8)\n", + "sns.boxplot(\n", + " ax=axs[1],\n", + " data=df_cas22,\n", + " y=\"cfm_sqrt_cosine\",\n", + " x=\"Precursor_type\",\n", + " hue=\"NCE\",\n", + " palette=bluepink_grad8,\n", + ")\n", "\n", "plt.show()" ] @@ -4077,24 +4575,48 @@ } ], "source": [ - "dgnn = pd.DataFrame({\"cosine\": df_cas22[\"spectral_sqrt_cosine\"], \"refl_cosine\": df_cas22[\"spectral_refl_cosine\"], \"ion_mode\": df_cas22[\"Precursor_type\"], \"challenge_num\": df_cas22[\"ChallengeNum\"], \"model\": \"GNN-based fragmentation\"})\n", - "dcfm = pd.DataFrame({\"cosine\": df_cas22[\"cfm_sqrt_cosine\"], \"refl_cosine\": df_cas22[\"cfm_refl_cosine\"], \"ion_mode\": df_cas22[\"Precursor_type\"], \"challenge_num\": df_cas22[\"ChallengeNum\"], \"model\": \"CFM-ID 4.4.7\"})\n", + "dgnn = pd.DataFrame(\n", + " {\n", + " \"cosine\": df_cas22[\"spectral_sqrt_cosine\"],\n", + " \"refl_cosine\": df_cas22[\"spectral_refl_cosine\"],\n", + " \"ion_mode\": df_cas22[\"Precursor_type\"],\n", + " \"challenge_num\": df_cas22[\"ChallengeNum\"],\n", + " \"model\": \"GNN-based fragmentation\",\n", + " }\n", + ")\n", + "dcfm = 
pd.DataFrame(\n", + " {\n", + " \"cosine\": df_cas22[\"cfm_sqrt_cosine\"],\n", + " \"refl_cosine\": df_cas22[\"cfm_refl_cosine\"],\n", + " \"ion_mode\": df_cas22[\"Precursor_type\"],\n", + " \"challenge_num\": df_cas22[\"ChallengeNum\"],\n", + " \"model\": \"CFM-ID 4.4.7\",\n", + " }\n", + ")\n", "\n", "D = pd.concat([dgnn, dcfm], axis=0, ignore_index=True)\n", - "#D = pd.concat([dgnn[dgnn[\"challenge_num\"] < 250], dcfm[dcfm[\"challenge_num\"] <= 250]], axis=0)\n", + "# D = pd.concat([dgnn[dgnn[\"challenge_num\"] < 250], dcfm[dcfm[\"challenge_num\"] <= 250]], axis=0)\n", "\n", "\n", "fig, ax = plt.subplots(1, 1, figsize=(10, 6.6), sharey=False)\n", "\n", - "sns.boxplot(ax=ax, data=D, y=\"cosine\", x=\"ion_mode\", hue=\"model\", palette=bluepink[:2], order=[\"[M+H]+\", \"[M-H]-\"])\n", - "#sns.boxplot(ax=axs[0], data=D, y=\"cosine\", x=\"model\", hue=\"ion_mode\", palette=bluepink[:1] * 2 + bluepink[:1] * 2)\n", + "sns.boxplot(\n", + " ax=ax,\n", + " data=D,\n", + " y=\"cosine\",\n", + " x=\"ion_mode\",\n", + " hue=\"model\",\n", + " palette=bluepink[:2],\n", + " order=[\"[M+H]+\", \"[M-H]-\"],\n", + ")\n", + "# sns.boxplot(ax=axs[0], data=D, y=\"cosine\", x=\"model\", hue=\"ion_mode\", palette=bluepink[:1] * 2 + bluepink[:1] * 2)\n", "ax.set_title(\"Spectral cosine similarity of CASMI-22 predictions\")\n", "plt.xlabel(\"\")\n", "plt.legend(loc=\"lower right\")\n", - "plt.subplots_adjust(right=0.975, left=0.11) #TODO FIX error\n", + "plt.subplots_adjust(right=0.975, left=0.11) # TODO FIX error\n", "adjust_box_widths(fig, 0.95)\n", - "#set_all_font_sizes(18)\n", - "#plt.savefig(f\"{home}/images/cosine_boxplot.svg\", format=\"svg\")\n", + "# set_all_font_sizes(18)\n", + "# plt.savefig(f\"{home}/images/cosine_boxplot.svg\", format=\"svg\")\n", "plt.show()" ] }, @@ -4117,7 +4639,15 @@ "source": [ "fig, ax = plt.subplots(1, 1, figsize=(10, 6.6), sharey=False)\n", "\n", - "sns.boxplot(ax=ax, data=D, y=\"refl_cosine\", x=\"ion_mode\", hue=\"model\", palette=bluepink[:2], order=[\"[M+H]+\", \"[M-H]-\"])\n", + "sns.boxplot(\n", + " ax=ax,\n", + " data=D,\n", + " y=\"refl_cosine\",\n", + " x=\"ion_mode\",\n", + " hue=\"model\",\n", + " palette=bluepink[:2],\n", + " order=[\"[M+H]+\", \"[M-H]-\"],\n", + ")\n", "ax.set_title(\"Spectral cosine similarity of CASMI-22 predictions\")\n", "plt.xlabel(\"\")\n", "plt.legend(loc=\"lower right\")\n", @@ -4225,10 +4755,18 @@ ], "source": [ "score = \"spectral_sqrt_cosine\"\n", - "fiora_res = {\"model\": \"Fiora\", \"CASMI16\": np.median(df_cas[score.replace(\"spectral\", \"merged\")]), \"CASMI22\": np.median(df_cas22[score])} \n", - "cfm_id = {\"model\": \"CFM-ID 4.4.7\", \"CASMI16\": np.median(df_cas[score.replace(\"spectral\", \"cfm\")].dropna()), \"CASMI22\": np.median(df_cas22[score.replace(\"spectral\", \"cfm\")])} \n", + "fiora_res = {\n", + " \"model\": \"Fiora\",\n", + " \"CASMI16\": np.median(df_cas[score.replace(\"spectral\", \"merged\")]),\n", + " \"CASMI22\": np.median(df_cas22[score]),\n", + "}\n", + "cfm_id = {\n", + " \"model\": \"CFM-ID 4.4.7\",\n", + " \"CASMI16\": np.median(df_cas[score.replace(\"spectral\", \"cfm\")].dropna()),\n", + " \"CASMI22\": np.median(df_cas22[score.replace(\"spectral\", \"cfm\")]),\n", + "}\n", "\n", - "summary = pd.DataFrame( [fiora_res, cfm_id])\n", + "summary = pd.DataFrame([fiora_res, cfm_id])\n", "print(\"Summary test sets\")\n", "summary" ] @@ -4313,11 +4851,54 @@ ], "source": [ "score = \"spectral_sqrt_cosine\"\n", - "fiora_res = {\"model\": \"Fiora\", \"CASMI16+\": np.median(df_cas[df_cas[\"Precursor_type\"] 
== \"[M+H]+\"][score.replace(\"spectral\", \"merged\")]), \"CASMI16-\":np.median(df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][score.replace(\"spectral\", \"merged\")]), \"CASMI16T+\": np.median(df_cast[df_cast[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"merged\")]), \"CASMI16T-\":np.median(df_cast[df_cast[\"Precursor_type\"] == \"[M-H]-\"][score.replace(\"spectral\", \"merged\")]), \"CASMI22+\": np.median(df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][score]), \"CASMI22-\": np.median(df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][score])} \n", + "fiora_res = {\n", + " \"model\": \"Fiora\",\n", + " \"CASMI16+\": np.median(\n", + " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"merged\")\n", + " ]\n", + " ),\n", + " \"CASMI16-\": np.median(\n", + " df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][\n", + " score.replace(\"spectral\", \"merged\")\n", + " ]\n", + " ),\n", + " \"CASMI16T+\": np.median(\n", + " df_cast[df_cast[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"merged\")\n", + " ]\n", + " ),\n", + " \"CASMI16T-\": np.median(\n", + " df_cast[df_cast[\"Precursor_type\"] == \"[M-H]-\"][\n", + " score.replace(\"spectral\", \"merged\")\n", + " ]\n", + " ),\n", + " \"CASMI22+\": np.median(df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][score]),\n", + " \"CASMI22-\": np.median(df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][score]),\n", + "}\n", "cfm_id = {\n", - " \"model\": \"CFM-ID 4.4.7\", \"CASMI16+\": np.median(df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"cfm\")]), \"CASMI16-\": np.median(df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][score.replace(\"spectral\", \"cfm\")].dropna()), \"CASMI22+\": np.median(df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"cfm\")]), \"CASMI22-\": np.median(df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][score.replace(\"spectral\", \"cfm\")])} \n", + " \"model\": \"CFM-ID 4.4.7\",\n", + " \"CASMI16+\": np.median(\n", + " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"cfm\")]\n", + " ),\n", + " \"CASMI16-\": np.median(\n", + " df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ].dropna()\n", + " ),\n", + " \"CASMI22+\": np.median(\n", + " df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ]\n", + " ),\n", + " \"CASMI22-\": np.median(\n", + " df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ]\n", + " ),\n", + "}\n", "\n", - "summaryPos = pd.DataFrame( [fiora_res, cfm_id])\n", + "summaryPos = pd.DataFrame([fiora_res, cfm_id])\n", "print(\"Summary test sets\")\n", "summaryPos" ] @@ -4349,17 +4930,29 @@ "source": [ "raise KeyboardInterrupt()\n", "\n", - "df_cas.loc[:,\"tanimoto\"] = np.nan\n", - "for i,d in df_cas.iterrows():\n", - " df_cas.at[i, \"tanimoto\"] = df_train[\"Metabolite\"].apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"])).max()\n", - " \n", - "df_cas22.loc[:,\"tanimoto\"] = np.nan\n", - "for i,d in df_cas22.iterrows():\n", - " df_cas22.at[i, \"tanimoto\"] = df_train[\"Metabolite\"].apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"])).max()\n", - " \n", - "df_val.loc[:,\"tanimoto\"] = np.nan\n", - "for i,d in df_val.iterrows():\n", - " df_val.at[i, \"tanimoto\"] = df_train[\"Metabolite\"].apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"])).max()" + 
"df_cas.loc[:, \"tanimoto\"] = np.nan\n", + "for i, d in df_cas.iterrows():\n", + " df_cas.at[i, \"tanimoto\"] = (\n", + " df_train[\"Metabolite\"]\n", + " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"]))\n", + " .max()\n", + " )\n", + "\n", + "df_cas22.loc[:, \"tanimoto\"] = np.nan\n", + "for i, d in df_cas22.iterrows():\n", + " df_cas22.at[i, \"tanimoto\"] = (\n", + " df_train[\"Metabolite\"]\n", + " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"]))\n", + " .max()\n", + " )\n", + "\n", + "df_val.loc[:, \"tanimoto\"] = np.nan\n", + "for i, d in df_val.iterrows():\n", + " df_val.at[i, \"tanimoto\"] = (\n", + " df_train[\"Metabolite\"]\n", + " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"]))\n", + " .max()\n", + " )" ] }, { @@ -4368,16 +4961,24 @@ "metadata": {}, "outputs": [], "source": [ - "fig, ax = plt.subplots(1,1, figsize=(12, 6))\n", - "sns.boxplot(data=df_cas, x=pd.cut(df_cas['tanimoto'], bins=[x/10.0 for x in list(range(0,10,1))]), y=\"merged_sqrt_cosine\") #multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", + "fig, ax = plt.subplots(1, 1, figsize=(12, 6))\n", + "sns.boxplot(\n", + " data=df_cas,\n", + " x=pd.cut(df_cas[\"tanimoto\"], bins=[x / 10.0 for x in list(range(0, 10, 1))]),\n", + " y=\"merged_sqrt_cosine\",\n", + ") # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", "plt.ylim([0, 1])\n", "plt.show()\n", "\n", "\n", - "fig, ax = plt.subplots(1,1, figsize=(12, 6))\n", - "sns.boxplot(data=df_val, x=pd.cut(df_val['tanimoto'], bins=[x/10.0 for x in list(range(0,10,1))]), y=\"spectral_sqrt_cosine\") #multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", + "fig, ax = plt.subplots(1, 1, figsize=(12, 6))\n", + "sns.boxplot(\n", + " data=df_val,\n", + " x=pd.cut(df_val[\"tanimoto\"], bins=[x / 10.0 for x in list(range(0, 10, 1))]),\n", + " y=\"spectral_sqrt_cosine\",\n", + ") # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", "plt.ylim([0, 1])\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -4386,9 +4987,17 @@ "metadata": {}, "outputs": [], "source": [ - "fig, ax = plt.subplots(1,1, figsize=(12, 6))\n", + "fig, ax = plt.subplots(1, 1, figsize=(12, 6))\n", "C = pd.concat([df_val, df_cas, df_cas22])\n", - "sns.pointplot(data=C, x=pd.cut(C['tanimoto'], bins=[x/10.0 for x in list(range(0,10,1))]), y=\"spectral_sqrt_cosine\", palette=tri_palette, capsize=.2, hue=\"test_set\", dodge=0.3) #multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", + "sns.pointplot(\n", + " data=C,\n", + " x=pd.cut(C[\"tanimoto\"], bins=[x / 10.0 for x in list(range(0, 10, 1))]),\n", + " y=\"spectral_sqrt_cosine\",\n", + " palette=tri_palette,\n", + " capsize=0.2,\n", + " hue=\"test_set\",\n", + " dodge=0.3,\n", + ") # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", "plt.ylim([0, 1])\n", "plt.show()" ] @@ -4409,8 +5018,8 @@ "source": [ "import scipy\n", "\n", - "df_val_unique = df_val.drop_duplicates(subset='SMILES', keep='first')\n", - "df_cas22_unique = df_cas22.drop_duplicates(subset='SMILES', keep='first')\n", + "df_val_unique = df_val.drop_duplicates(subset=\"SMILES\", keep=\"first\")\n", + "df_cas22_unique = df_cas22.drop_duplicates(subset=\"SMILES\", keep=\"first\")\n", "\n", 
"#\n", "# Recalibration\n", @@ -4419,14 +5028,26 @@ "\n", "df_val_unique[\"RT_pred_cal\"] = df_val_unique[\"RT_pred\"]\n", "\n", - "confident_ids = df_cas.sort_values(by=\"merged_sqrt_cosine\", ascending=False).dropna(subset=[\"RETENTIONTIME\"]).head(10)\n", - "rt_slope, rt_intercept, r_value, p_value, std_err = scipy.stats.linregress(confident_ids[\"RT_pred\"].astype(float), confident_ids[\"RETENTIONTIME\"])\n", - "df_cas[\"RT_pred_cal\"] = rt_intercept + df_cas[\"RT_pred\"] * rt_slope \n", + "confident_ids = (\n", + " df_cas.sort_values(by=\"merged_sqrt_cosine\", ascending=False)\n", + " .dropna(subset=[\"RETENTIONTIME\"])\n", + " .head(10)\n", + ")\n", + "rt_slope, rt_intercept, r_value, p_value, std_err = scipy.stats.linregress(\n", + " confident_ids[\"RT_pred\"].astype(float), confident_ids[\"RETENTIONTIME\"]\n", + ")\n", + "df_cas[\"RT_pred_cal\"] = rt_intercept + df_cas[\"RT_pred\"] * rt_slope\n", "\n", "\n", - "confident_ids = df_cas22_unique.sort_values(by=\"spectral_sqrt_cosine\", ascending=False).dropna(subset=[\"RETENTIONTIME\"]).head(10)\n", - "rt_slope, rt_intercept, r_value, p_value, std_err = scipy.stats.linregress(confident_ids[\"RT_pred\"].astype(float), confident_ids[\"RETENTIONTIME\"])\n", - "df_cas22_unique[\"RT_pred_cal\"] = rt_intercept + df_cas22_unique[\"RT_pred\"] * rt_slope " + "confident_ids = (\n", + " df_cas22_unique.sort_values(by=\"spectral_sqrt_cosine\", ascending=False)\n", + " .dropna(subset=[\"RETENTIONTIME\"])\n", + " .head(10)\n", + ")\n", + "rt_slope, rt_intercept, r_value, p_value, std_err = scipy.stats.linregress(\n", + " confident_ids[\"RT_pred\"].astype(float), confident_ids[\"RETENTIONTIME\"]\n", + ")\n", + "df_cas22_unique[\"RT_pred_cal\"] = rt_intercept + df_cas22_unique[\"RT_pred\"] * rt_slope" ] }, { @@ -4435,7 +5056,16 @@ "metadata": {}, "outputs": [], "source": [ - "RT = pd.concat([df_val_unique[[\"RETENTIONTIME\", \"RT_pred\", \"RT_pred_cal\", \"RT_dif\", \"library\"]], df_cas[[\"RETENTIONTIME\", \"RT_pred\", \"RT_pred_cal\", \"RT_dif\", \"library\"]], df_cas22_unique[[\"RETENTIONTIME\", \"RT_pred\", \"RT_pred_cal\", \"RT_dif\", \"library\"]]], ignore_index=True)\n", + "RT = pd.concat(\n", + " [\n", + " df_val_unique[[\"RETENTIONTIME\", \"RT_pred\", \"RT_pred_cal\", \"RT_dif\", \"library\"]],\n", + " df_cas[[\"RETENTIONTIME\", \"RT_pred\", \"RT_pred_cal\", \"RT_dif\", \"library\"]],\n", + " df_cas22_unique[\n", + " [\"RETENTIONTIME\", \"RT_pred\", \"RT_pred_cal\", \"RT_dif\", \"library\"]\n", + " ],\n", + " ],\n", + " ignore_index=True,\n", + ")\n", "RT = RT.dropna(subset=[\"RETENTIONTIME\"])\n", "RT[\"RT_rel_dif\"] = RT[\"RT_dif\"] / RT[\"RETENTIONTIME\"]" ] @@ -4446,18 +5076,30 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "fig, axs = plt.subplots(1, 2, figsize=(12, 8), gridspec_kw={'width_ratios': [5, 2]}, sharey=False)\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12, 8), gridspec_kw={\"width_ratios\": [5, 2]}, sharey=False\n", + ")\n", "plt.subplots_adjust(wspace=0.1)\n", "\n", - "sns.scatterplot(ax=axs[0], data=RT, x=\"RETENTIONTIME\", y=\"RT_pred_cal\", hue=\"library\", palette=tri_palette, style=\"library\", color=\"gray\")\n", - "axs[0].set_ylim([0,30])\n", - "axs[0].set_xlim([0,30])\n", + "sns.scatterplot(\n", + " ax=axs[0],\n", + " data=RT,\n", + " x=\"RETENTIONTIME\",\n", + " y=\"RT_pred_cal\",\n", + " hue=\"library\",\n", + " palette=tri_palette,\n", + " style=\"library\",\n", + " color=\"gray\",\n", + ")\n", + "axs[0].set_ylim([0, 30])\n", + "axs[0].set_xlim([0, 30])\n", "axs[0].set_ylabel(\"Predicted 
retention time\")\n", "axs[0].set_xlabel(\"Observed retention time\")\n", - "sns.lineplot(ax=axs[0], x=[0,100], y=[0,100], color=\"black\")\n", + "sns.lineplot(ax=axs[0], x=[0, 100], y=[0, 100], color=\"black\")\n", "\n", - "sns.boxplot(ax=axs[1], data=RT, y=\"RT_dif\", palette=tri_palette, x=\"library\", showfliers=False)\n", + "sns.boxplot(\n", + " ax=axs[1], data=RT, y=\"RT_dif\", palette=tri_palette, x=\"library\", showfliers=False\n", + ")\n", "axs[1].set_xlabel(\"\")\n", "axs[1].set_ylabel(\"RT difference (in minutes)\")\n", "axs[1].yaxis.set_label_position(\"right\")\n", @@ -4471,8 +5113,20 @@ "metadata": {}, "outputs": [], "source": [ - "print(np.corrcoef(RT[\"RETENTIONTIME\"].values, RT[\"RT_pred_cal\"].values, dtype=float)[0,1])\n", - "print(np.corrcoef(df_val_unique.dropna(subset=[\"RETENTIONTIME\"])[\"RETENTIONTIME\"], df_val_unique.dropna(subset=[\"RETENTIONTIME\"])[\"RT_pred\"], dtype=float)[0,1], np.corrcoef(df_cas[\"RETENTIONTIME\"], df_cas[\"RT_pred\"], dtype=float)[0,1], np.corrcoef(df_cas22_unique[\"RETENTIONTIME\"], df_cas22_unique[\"RT_pred\"], dtype=float)[0,1])" + "print(\n", + " np.corrcoef(RT[\"RETENTIONTIME\"].values, RT[\"RT_pred_cal\"].values, dtype=float)[0, 1]\n", + ")\n", + "print(\n", + " np.corrcoef(\n", + " df_val_unique.dropna(subset=[\"RETENTIONTIME\"])[\"RETENTIONTIME\"],\n", + " df_val_unique.dropna(subset=[\"RETENTIONTIME\"])[\"RT_pred\"],\n", + " dtype=float,\n", + " )[0, 1],\n", + " np.corrcoef(df_cas[\"RETENTIONTIME\"], df_cas[\"RT_pred\"], dtype=float)[0, 1],\n", + " np.corrcoef(\n", + " df_cas22_unique[\"RETENTIONTIME\"], df_cas22_unique[\"RT_pred\"], dtype=float\n", + " )[0, 1],\n", + ")" ] }, { @@ -4481,25 +5135,42 @@ "metadata": {}, "outputs": [], "source": [ - "fig, axs = plt.subplots(2, 1, figsize=(12, 14), gridspec_kw={'height_ratios': [1, 5]}, sharex=True)\n", + "fig, axs = plt.subplots(\n", + " 2, 1, figsize=(12, 14), gridspec_kw={\"height_ratios\": [1, 5]}, sharex=True\n", + ")\n", "plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", "\n", "\n", - "#sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "sns.kdeplot(ax=axs[0], data=RT, x=\"RETENTIONTIME\", bw_adjust=0.25, color=\"black\", multiple=\"layer\", hue=\"library\", palette=tri_palette) #hue=\"Precursor_type\", \n", - "axs[0].spines['top'].set_visible(False)\n", - "axs[0].spines['right'].set_visible(False)\n", - "\n", - "\n", - "sns.scatterplot(ax=axs[1], data=df_val_unique, x=\"RETENTIONTIME\", y=\"RT_pred_cal\", color=\"gray\")#, hue=\"library\", palette=tri_palette, style=\"library\", color=\"gray\")\n", - "axs[1].set_ylim([0,df_val_unique[\"RETENTIONTIME\"].max() + 1 ])\n", - "axs[1].set_xlim([0,df_val_unique[\"RETENTIONTIME\"].max() + 1])\n", + "# sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "sns.kdeplot(\n", + " ax=axs[0],\n", + " data=RT,\n", + " x=\"RETENTIONTIME\",\n", + " bw_adjust=0.25,\n", + " color=\"black\",\n", + " multiple=\"layer\",\n", + " hue=\"library\",\n", + " palette=tri_palette,\n", + ") # hue=\"Precursor_type\",\n", + "axs[0].spines[\"top\"].set_visible(False)\n", + "axs[0].spines[\"right\"].set_visible(False)\n", + "\n", + "\n", + "sns.scatterplot(\n", + " ax=axs[1], data=df_val_unique, x=\"RETENTIONTIME\", y=\"RT_pred_cal\", 
color=\"gray\"\n", + ") # , hue=\"library\", palette=tri_palette, style=\"library\", color=\"gray\")\n", + "axs[1].set_ylim([0, df_val_unique[\"RETENTIONTIME\"].max() + 1])\n", + "axs[1].set_xlim([0, df_val_unique[\"RETENTIONTIME\"].max() + 1])\n", "axs[1].set_ylabel(\"Predicted retention time\")\n", "axs[1].set_xlabel(\"Observed retention time\")\n", "line = [0, 100]\n", "sns.lineplot(ax=axs[1], x=line, y=line, color=\"black\")\n", - "sns.lineplot(ax=axs[1], x=line, y=[x + 20/60.0 for x in line], color=\"black\", linestyle='--')\n", - "sns.lineplot(ax=axs[1], x=line, y=[x - 20/60.0 for x in line], color=\"black\", linestyle='--')" + "sns.lineplot(\n", + " ax=axs[1], x=line, y=[x + 20 / 60.0 for x in line], color=\"black\", linestyle=\"--\"\n", + ")\n", + "sns.lineplot(\n", + " ax=axs[1], x=line, y=[x - 20 / 60.0 for x in line], color=\"black\", linestyle=\"--\"\n", + ")" ] }, { @@ -4522,31 +5193,55 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "fig, axs = plt.subplots(2, 1, figsize=(12, 14), gridspec_kw={'height_ratios': [1, 5]}, sharex=True)\n", + "fig, axs = plt.subplots(\n", + " 2, 1, figsize=(12, 14), gridspec_kw={\"height_ratios\": [1, 5]}, sharex=True\n", + ")\n", "plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", "\n", "# df_val_unique = df_val.drop_duplicates(subset='SMILES', keep='first')\n", "# df_cas22_unique = df_cas22.drop_duplicates(subset='SMILES', keep='first')\n", "\n", - "CCS = pd.concat([df_val_unique[[\"CCS\", \"CCS_pred\", \"library\"]], df_cas[[\"CCS\", \"CCS_pred\", \"library\"]], df_cas22_unique[[\"CCS\", \"CCS_pred\", \"library\"]] ], ignore_index=True)\n", - "\n", - "\n", - "#sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "sns.kdeplot(ax=axs[0], data=CCS, x=\"CCS\", bw_adjust=0.25, color=\"black\", multiple=\"stack\") #hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "axs[0].spines['top'].set_visible(False)\n", - "axs[0].spines['right'].set_visible(False)\n", - "\n", - "\n", - "sns.scatterplot(ax=axs[1], data=CCS, x=\"CCS\", y=\"CCS_pred\", s=25, hue=\"library\", palette=tri_palette, style=\"library\", color=\"gray\")#, color=\"blue\", edgecolor=\"blue\")#, \n", - "axs[1].set_ylim([df_val_unique[\"CCS\"].min() - 10,df_val_unique[\"CCS\"].max() + 10])\n", - "axs[1].set_xlim([df_val_unique[\"CCS\"].min() - 10,df_val_unique[\"CCS\"].max() + 10])\n", + "CCS = pd.concat(\n", + " [\n", + " df_val_unique[[\"CCS\", \"CCS_pred\", \"library\"]],\n", + " df_cas[[\"CCS\", \"CCS_pred\", \"library\"]],\n", + " df_cas22_unique[[\"CCS\", \"CCS_pred\", \"library\"]],\n", + " ],\n", + " ignore_index=True,\n", + ")\n", + "\n", + "\n", + "# sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "sns.kdeplot(\n", + " ax=axs[0], data=CCS, x=\"CCS\", bw_adjust=0.25, color=\"black\", multiple=\"stack\"\n", + ") # hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "axs[0].spines[\"top\"].set_visible(False)\n", + "axs[0].spines[\"right\"].set_visible(False)\n", + "\n", + "\n", + "sns.scatterplot(\n", + " ax=axs[1],\n", + " data=CCS,\n", + " x=\"CCS\",\n", + " y=\"CCS_pred\",\n", + " s=25,\n", + " hue=\"library\",\n", + " palette=tri_palette,\n", + " style=\"library\",\n", + " 
color=\"gray\",\n", + ") # , color=\"blue\", edgecolor=\"blue\")#,\n", + "axs[1].set_ylim([df_val_unique[\"CCS\"].min() - 10, df_val_unique[\"CCS\"].max() + 10])\n", + "axs[1].set_xlim([df_val_unique[\"CCS\"].min() - 10, df_val_unique[\"CCS\"].max() + 10])\n", "axs[1].set_ylabel(\"Predicted CCS\")\n", "axs[1].set_xlabel(\"Observed CCS\")\n", - "line=[df_val_unique[\"CCS\"].min() - 10,df_val_unique[\"CCS\"].max() + 10]\n", + "line = [df_val_unique[\"CCS\"].min() - 10, df_val_unique[\"CCS\"].max() + 10]\n", "sns.lineplot(ax=axs[1], x=line, y=line, color=\"black\")\n", - "sns.lineplot(ax=axs[1], x=line, y=[1.1*x for x in line], color=\"black\", linestyle='--')\n", - "sns.lineplot(ax=axs[1], x=line, y=[0.9*x for x in line], color=\"black\", linestyle='--')\n", + "sns.lineplot(\n", + " ax=axs[1], x=line, y=[1.1 * x for x in line], color=\"black\", linestyle=\"--\"\n", + ")\n", + "sns.lineplot(\n", + " ax=axs[1], x=line, y=[0.9 * x for x in line], color=\"black\", linestyle=\"--\"\n", + ")\n", "plt.show()" ] }, @@ -4574,31 +5269,55 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "fig, axs = plt.subplots(2, 1, figsize=(12, 14), gridspec_kw={'height_ratios': [1, 5]}, sharex=True)\n", + "fig, axs = plt.subplots(\n", + " 2, 1, figsize=(12, 14), gridspec_kw={\"height_ratios\": [1, 5]}, sharex=True\n", + ")\n", "plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", "\n", "# df_val_unique = df_val.drop_duplicates(subset='SMILES', keep='first')\n", "# df_cas22_unique = df_cas22.drop_duplicates(subset='SMILES', keep='first')\n", "\n", - "CCS = pd.concat([df_val_unique[[\"CCS\", \"CCS_pred\", \"library\"]], df_cas[[\"CCS\", \"CCS_pred\", \"library\"]], df_cas22_unique[[\"CCS\", \"CCS_pred\", \"library\"]] ], ignore_index=True)\n", - "\n", - "\n", - "#sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "sns.kdeplot(ax=axs[0], data=CCS, x=\"CCS\", bw_adjust=0.25, color=\"black\", multiple=\"stack\") #hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "axs[0].spines['top'].set_visible(False)\n", - "axs[0].spines['right'].set_visible(False)\n", - "\n", - "\n", - "sns.scatterplot(ax=axs[1], data=CCS, x=\"CCS\", y=\"CCS_pred\", s=25, hue=\"library\", palette=tri_palette, style=\"library\", color=\"gray\")#, color=\"blue\", edgecolor=\"blue\")#, \n", - "axs[1].set_ylim([df_val_unique[\"CCS\"].min() - 10,df_val_unique[\"CCS\"].max() + 10])\n", - "axs[1].set_xlim([df_val_unique[\"CCS\"].min() - 10,df_val_unique[\"CCS\"].max() + 10])\n", + "CCS = pd.concat(\n", + " [\n", + " df_val_unique[[\"CCS\", \"CCS_pred\", \"library\"]],\n", + " df_cas[[\"CCS\", \"CCS_pred\", \"library\"]],\n", + " df_cas22_unique[[\"CCS\", \"CCS_pred\", \"library\"]],\n", + " ],\n", + " ignore_index=True,\n", + ")\n", + "\n", + "\n", + "# sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "sns.kdeplot(\n", + " ax=axs[0], data=CCS, x=\"CCS\", bw_adjust=0.25, color=\"black\", multiple=\"stack\"\n", + ") # hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "axs[0].spines[\"top\"].set_visible(False)\n", + "axs[0].spines[\"right\"].set_visible(False)\n", + "\n", + "\n", + "sns.scatterplot(\n", + " ax=axs[1],\n", + " data=CCS,\n", + " x=\"CCS\",\n", + " 
y=\"CCS_pred\",\n", + " s=25,\n", + " hue=\"library\",\n", + " palette=tri_palette,\n", + " style=\"library\",\n", + " color=\"gray\",\n", + ") # , color=\"blue\", edgecolor=\"blue\")#,\n", + "axs[1].set_ylim([df_val_unique[\"CCS\"].min() - 10, df_val_unique[\"CCS\"].max() + 10])\n", + "axs[1].set_xlim([df_val_unique[\"CCS\"].min() - 10, df_val_unique[\"CCS\"].max() + 10])\n", "axs[1].set_ylabel(\"Predicted CCS\")\n", "axs[1].set_xlabel(\"Observed CCS\")\n", - "line=[df_val_unique[\"CCS\"].min() - 10,df_val_unique[\"CCS\"].max() + 10]\n", + "line = [df_val_unique[\"CCS\"].min() - 10, df_val_unique[\"CCS\"].max() + 10]\n", "sns.lineplot(ax=axs[1], x=line, y=line, color=\"black\")\n", - "sns.lineplot(ax=axs[1], x=line, y=[1.1*x for x in line], color=\"black\", linestyle='--')\n", - "sns.lineplot(ax=axs[1], x=line, y=[0.9*x for x in line], color=\"black\", linestyle='--')\n", + "sns.lineplot(\n", + " ax=axs[1], x=line, y=[1.1 * x for x in line], color=\"black\", linestyle=\"--\"\n", + ")\n", + "sns.lineplot(\n", + " ax=axs[1], x=line, y=[0.9 * x for x in line], color=\"black\", linestyle=\"--\"\n", + ")\n", "plt.show()" ] }, @@ -4609,13 +5328,43 @@ "outputs": [], "source": [ "print(\"Pearson Corr Coef:\")\n", - "print(\"GNN\", np.corrcoef(df_val_unique.dropna(subset=[\"CCS\"])[\"CCS\"], df_val_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(), dtype=float)[0,1])\n", - "print(\"pMZ\", np.corrcoef(df_val_unique.dropna(subset=[\"CCS\"])[\"CCS\"], df_val_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna(), dtype=float)[0,1])\n", - "\n", - "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(df_train.dropna(subset=[\"CCS\"])[\"PRECURSORMZ\"], df_train.dropna(subset=[\"CCS\"])[\"CCS\"])\n", + "print(\n", + " \"GNN\",\n", + " np.corrcoef(\n", + " df_val_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_val_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " dtype=float,\n", + " )[0, 1],\n", + ")\n", + "print(\n", + " \"pMZ\",\n", + " np.corrcoef(\n", + " df_val_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_val_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna(),\n", + " dtype=float,\n", + " )[0, 1],\n", + ")\n", + "\n", + "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(\n", + " df_train.dropna(subset=[\"CCS\"])[\"PRECURSORMZ\"],\n", + " df_train.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + ")\n", "print(\"R2\")\n", - "print(\"GNN\", r2_score(df_val_unique.dropna(subset=[\"CCS\"])[\"CCS\"], df_val_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna()))\n", - "print(\"pMZ\", r2_score(df_val_unique.dropna(subset=[\"CCS\"])[\"CCS\"], intercept + slope *df_val_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna()))" + "print(\n", + " \"GNN\",\n", + " r2_score(\n", + " df_val_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_val_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " ),\n", + ")\n", + "print(\n", + " \"pMZ\",\n", + " r2_score(\n", + " df_val_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " intercept\n", + " + slope * df_val_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna(),\n", + " ),\n", + ")" ] }, { @@ -4625,13 +5374,42 @@ "outputs": [], "source": [ "print(\"Pearson Corr Coef:\")\n", - "print(\"GNN\", np.corrcoef(df_cas.dropna(subset=[\"CCS\"])[\"CCS\"], df_cas.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(), dtype=float)[0,1])\n", - "print(\"pMZ\", np.corrcoef(df_cas.dropna(subset=[\"CCS\"])[\"CCS\"], df_cas.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna(), dtype=float)[0,1])\n", - 
"\n", - "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(df_train.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"], df_train.dropna(subset=[\"CCS\"])[\"CCS\"])\n", + "print(\n", + " \"GNN\",\n", + " np.corrcoef(\n", + " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_cas.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " dtype=float,\n", + " )[0, 1],\n", + ")\n", + "print(\n", + " \"pMZ\",\n", + " np.corrcoef(\n", + " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_cas.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna(),\n", + " dtype=float,\n", + " )[0, 1],\n", + ")\n", + "\n", + "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(\n", + " df_train.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"],\n", + " df_train.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + ")\n", "print(\"R2\")\n", - "print(\"GNN\", r2_score(df_cas.dropna(subset=[\"CCS\"])[\"CCS\"], df_cas.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna()))\n", - "print(\"pMZ\", r2_score(df_cas.dropna(subset=[\"CCS\"])[\"CCS\"], intercept + slope *df_cas.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna()))" + "print(\n", + " \"GNN\",\n", + " r2_score(\n", + " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_cas.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " ),\n", + ")\n", + "print(\n", + " \"pMZ\",\n", + " r2_score(\n", + " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " intercept + slope * df_cas.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna(),\n", + " ),\n", + ")" ] }, { @@ -4641,7 +5419,7 @@ "outputs": [], "source": [ "# names = df_cas[\"NAME\"]\n", - "# print(df_cas[[\"NAME\", \"RTINSECONDS\", \"RETENTIONTIME\", \"RT_pred\", \"ChallengeName\"]]) \n", + "# print(df_cas[[\"NAME\", \"RTINSECONDS\", \"RETENTIONTIME\", \"RT_pred\", \"ChallengeName\"]])\n", "# # TODO which RT to use for CASMI22 -> ChallengeRT" ] }, @@ -4680,7 +5458,7 @@ "# for i, d in df_cas.iterrows():\n", "# n = d[\"NAME\"]\n", "# rt_train = ff[ff[\"NAME\"] == n][\"RETENTIONTIME\"].mean()\n", - " \n", + "\n", "# if abs(d[\"RETENTIONTIME\"] - rt_train) > 1:\n", "# c+=1\n", "# c" @@ -4705,30 +5483,56 @@ "outputs": [], "source": [ "def double_mirrorplot22(i, model_title=\"Fiora\"):\n", - " fig, axs = plt.subplots(1, 3, figsize=(16.8, 4.2), gridspec_kw={'width_ratios': [1, 3, 3]}, sharey=False)\n", - " \n", + " fig, axs = plt.subplots(\n", + " 1, 3, figsize=(16.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3, 3]}, sharey=False\n", + " )\n", + "\n", " plt.subplots_adjust(right=0.975, left=0.025)\n", - " \n", - " img = df_cas22.iloc[i][\"Metabolite\"].draw(ax= axs[0])\n", + "\n", + " img = df_cas22.iloc[i][\"Metabolite\"].draw(ax=axs[0])\n", "\n", " axs[0].grid(False)\n", - " axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False) \n", + " axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " )\n", " axs[0].set_title(df_cas22.iloc[i][\"ChallengeName\"])\n", " axs[0].imshow(img)\n", " axs[0].axis(\"off\")\n", "\n", - " sv.plot_spectrum({\"peaks\": df_cas22.iloc[i][\"peaks\"]}, {\"peaks\": df_cas22.iloc[i][\"sim_peaks\"]}, ax=axs[1])\n", + " sv.plot_spectrum(\n", + " {\"peaks\": df_cas22.iloc[i][\"peaks\"]},\n", + " {\"peaks\": df_cas22.iloc[i][\"sim_peaks\"]},\n", + " ax=axs[1],\n", + " )\n", " axs[1].title.set_text(model_title)\n", - " patch1 = mpatches.Patch(color='limegreen' if df_cas22.iloc[i][\"cfm_sqrt_cosine\"] < df_cas22.iloc[i][\"spectral_sqrt_cosine\"] else \"orangered\", label=f'cosine 
{df_cas22.iloc[i][\"spectral_sqrt_cosine\"]:.02f}')\n", + " patch1 = mpatches.Patch(\n", + " color=\"limegreen\"\n", + " if df_cas22.iloc[i][\"cfm_sqrt_cosine\"]\n", + " < df_cas22.iloc[i][\"spectral_sqrt_cosine\"]\n", + " else \"orangered\",\n", + " label=f\"cosine {df_cas22.iloc[i]['spectral_sqrt_cosine']:.02f}\",\n", + " )\n", " axs[1].legend(handles=[patch1])\n", "\n", - " sv.plot_spectrum({\"peaks\": df_cas22.iloc[i][\"peaks\"]}, {\"peaks\": df_cas22.iloc[i][\"cfm_peaks\"]} if df_cas22.iloc[i][\"cfm_peaks\"] else {\"peaks\": {\"mz\": [0], \"intensity\": [0]}}, ax=axs[2])\n", - " axs[2].title.set_text(f'CFM-ID 4.4.7')\n", - " \n", - " patch2 = mpatches.Patch(color='limegreen' if df_cas22.iloc[i][\"cfm_sqrt_cosine\"] > df_cas22.iloc[i][\"spectral_sqrt_cosine\"] else \"orangered\", label=f'cosine {df_cas22.iloc[i][\"cfm_sqrt_cosine\"]:.02f}', )\n", + " sv.plot_spectrum(\n", + " {\"peaks\": df_cas22.iloc[i][\"peaks\"]},\n", + " {\"peaks\": df_cas22.iloc[i][\"cfm_peaks\"]}\n", + " if df_cas22.iloc[i][\"cfm_peaks\"]\n", + " else {\"peaks\": {\"mz\": [0], \"intensity\": [0]}},\n", + " ax=axs[2],\n", + " )\n", + " axs[2].title.set_text(f\"CFM-ID 4.4.7\")\n", + "\n", + " patch2 = mpatches.Patch(\n", + " color=\"limegreen\"\n", + " if df_cas22.iloc[i][\"cfm_sqrt_cosine\"]\n", + " > df_cas22.iloc[i][\"spectral_sqrt_cosine\"]\n", + " else \"orangered\",\n", + " label=f\"cosine {df_cas22.iloc[i]['cfm_sqrt_cosine']:.02f}\",\n", + " )\n", " axs[2].legend(handles=[patch2])\n", - " \n", - " return fig, axs\n" + "\n", + " return fig, axs" ] }, { @@ -4743,7 +5547,7 @@ "# for i in range(df_cas22.shape[0]):\n", "# print(i)\n", "# if plot_examples:\n", - "# fig, axs = double_mirrorplot22(i) \n", + "# fig, axs = double_mirrorplot22(i)\n", "# plt.show()\n", "# if df_cas22.iloc[i][\"cfm_sqrt_cosine\"] > df_cas22.iloc[i][\"spectral_sqrt_cosine\"]:\n", "# lower += 1\n", @@ -4764,7 +5568,9 @@ "metadata": {}, "outputs": [], "source": [ - "df_val[\"pp\"] = df_val[\"Metabolite\"].apply(lambda x: (x.precursor_count / (sum(x.compiled_countsALL) / 2.0)).tolist() )" + "df_val[\"pp\"] = df_val[\"Metabolite\"].apply(\n", + " lambda x: (x.precursor_count / (sum(x.compiled_countsALL) / 2.0)).tolist()\n", + ")" ] }, { @@ -4773,7 +5579,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.histplot(df_val, x=\"pp\", hue=\"PRECURSORTYPE\", multiple=\"dodge\", common_norm=False, stat=\"density\")" + "sns.histplot(\n", + " df_val,\n", + " x=\"pp\",\n", + " hue=\"PRECURSORTYPE\",\n", + " multiple=\"dodge\",\n", + " common_norm=False,\n", + " stat=\"density\",\n", + ")" ] }, { @@ -4782,7 +5595,9 @@ "metadata": {}, "outputs": [], "source": [ - "sns.kdeplot(data=df_val, x=\"pp\", y=\"CE\", hue=\"Precursor_type\") #multiple=\"dodge\", common_norm=False, stat=\"density\")" + "sns.kdeplot(\n", + " data=df_val, x=\"pp\", y=\"CE\", hue=\"Precursor_type\"\n", + ") # multiple=\"dodge\", common_norm=False, stat=\"density\")" ] }, { @@ -4793,7 +5608,7 @@ "source": [ "import umap\n", "\n", - "reducer = umap.UMAP(random_state=seed, n_neighbors=10, min_dist= 0.01)\n", + "reducer = umap.UMAP(random_state=seed, n_neighbors=10, min_dist=0.01)\n", "reducer.fit(ms2ds_vectors)" ] }, @@ -4841,7 +5656,7 @@ "metadata": {}, "outputs": [], "source": [ - "#df.to_csv(f'{home}/data/metabolites/preprocessed/datasplits_Jan24.csv')" + "# df.to_csv(f'{home}/data/metabolites/preprocessed/datasplits_Jan24.csv')" ] } ], diff --git a/notebooks/grid_search.ipynb b/notebooks/grid_search.ipynb index 69a8100..d4e58c6 100644 --- a/notebooks/grid_search.ipynb +++ 
b/notebooks/grid_search.ipynb @@ -27,7 +27,7 @@ "import torch\n", "\n", "seed = 42\n", - "#torch.set_default_dtype(torch.float64)\n", + "# torch.set_default_dtype(torch.float64)\n", "torch.manual_seed(seed)\n", "torch.set_printoptions(precision=2, sci_mode=False)\n", "\n", @@ -40,18 +40,20 @@ "# Load Modules\n", "sys.path.append(\"..\")\n", "from os.path import expanduser\n", + "\n", "home = expanduser(\"~\")\n", "from fiora.MOL.constants import DEFAULT_PPM, PPM, DEFAULT_MODES\n", "from fiora.IO.LibraryLoader import LibraryLoader\n", - "from fiora.MOL.FragmentationTree import FragmentationTree \n", + "from fiora.MOL.FragmentationTree import FragmentationTree\n", "import fiora.visualization.spectrum_visualizer as sv\n", "\n", "from sklearn.metrics import r2_score\n", "import scipy\n", "from rdkit import RDLogger\n", - "RDLogger.DisableLog('rdApp.*')\n", "\n", - "print(f'Working with Python {sys.version}')\n" + "RDLogger.DisableLog(\"rdApp.*\")\n", + "\n", + "print(f\"Working with Python {sys.version}\")" ] }, { @@ -77,12 +79,15 @@ ], "source": [ "from typing import Literal\n", + "\n", "lib: Literal[\"NIST\", \"MSDIAL\", \"NIST/MSDIAL\"] = \"NIST/MSDIAL\"\n", "print(f\"Preparing {lib} library\")\n", "\n", - "test_run = False # Default: False\n", + "test_run = False # Default: False\n", "if test_run:\n", - " print(\"+++ This is a test run with a small subset of data points. Results are not representative. +++\")" + " print(\n", + " \"+++ This is a test run with a small subset of data points. Results are not representative. +++\"\n", + " )" ] }, { @@ -93,35 +98,37 @@ "source": [ "# key map to read metadata from pandas DataFrame\n", "metadata_key_map = {\n", - " \"name\": \"Name\",\n", - " \"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"ionization\": \"Ionization\",\n", - " \"precursor_mz\": \"PrecursorMZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", - " \"ccs\": \"CCS\"\n", - " }\n", + " \"name\": \"Name\",\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"ionization\": \"Ionization\",\n", + " \"precursor_mz\": \"PrecursorMZ\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"RETENTIONTIME\",\n", + " \"ccs\": \"CCS\",\n", + "}\n", "\n", "\n", "#\n", "# Load specified libraries and align metadata\n", "#\n", "\n", + "\n", "def load_training_data():\n", " L = LibraryLoader()\n", " df = L.load_from_csv(f\"{home}/data/metabolites/preprocessed/datasplits_Jan24.csv\")\n", " return df\n", "\n", + "\n", "df = load_training_data()\n", "\n", "# Restore dictionary values\n", "dict_columns = [\"peaks\", \"summary\"]\n", "for col in dict_columns:\n", - " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", - " #df[col] = df[col].apply(ast.literal_eval)\n", - " \n", - "df['group_id'] = df['group_id'].astype(int)\n" + " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace(\"nan\", \"None\")))\n", + " # df[col] = df[col].apply(ast.literal_eval)\n", + "\n", + "df[\"group_id\"] = df[\"group_id\"].astype(int)" ] }, { @@ -134,7 +141,7 @@ "# TODO: ATTENTION\n", "#\n", "\n", - "df[\"ppm_peak_tolerance\"] = 100 * PPM\n" + "df[\"ppm_peak_tolerance\"] = 100 * PPM" ] }, { @@ -155,8 +162,8 @@ "\n", "\n", "if test_run:\n", - " df = df.iloc[:10000,:]\n", - " #df = df.iloc[5000:20000,:]\n", + " df = df.iloc[:10000, :]\n", + " # df = df.iloc[5000:20000,:]\n", "\n", "\n", "df[\"Metabolite\"] = 
df[\"SMILES\"].apply(Metabolite)\n", @@ -164,18 +171,30 @@ "\n", "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", - "setup_encoder = SetupFeatureEncoder(feature_list=[\"collision_energy\", \"molecular_weight\", \"precursor_mode\", \"instrument\"])\n", - "rt_encoder = SetupFeatureEncoder(feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"])\n", + "setup_encoder = SetupFeatureEncoder(\n", + " feature_list=[\n", + " \"collision_energy\",\n", + " \"molecular_weight\",\n", + " \"precursor_mode\",\n", + " \"instrument\",\n", + " ]\n", + ")\n", + "rt_encoder = SetupFeatureEncoder(\n", + " feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"]\n", + ")\n", "\n", - "setup_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit \n", - "setup_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n", - "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n", + "setup_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", + "setup_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", + "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", "\n", "df[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]) , axis=1)\n", + "df.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]), axis=1)\n", "\n", - "#df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", - "df.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + "# df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", + "df.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + ")\n", "df[\"num_per_group\"] = df[\"group_id\"].map(df[\"group_id\"].value_counts())\n", "df[\"loss_weight\"] = 1.0 / df[\"num_per_group\"]\n", "df.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)" @@ -189,7 +208,12 @@ "source": [ "%%capture\n", "df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "df.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]), axis=1)" + "df.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -218,6 +242,7 @@ "\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", + "\n", "sns.violinplot(df, y=\"coverage\", hue=\"lib\", split=True)\n", "plt.show()" ] @@ -251,7 +276,7 @@ "metadata": {}, "outputs": [], "source": [ - "df = df[df[\"coverage\"] > 0.5] #TODO: ATTENTION 50 ppm + 50% cov cutoff" + "df = df[df[\"coverage\"] > 0.5] # TODO: ATTENTION 50 ppm + 50% cov cutoff" ] }, { @@ -317,23 +342,35 @@ "df_cas[\"Metabolite\"] = df_cas[\"SMILES\"].apply(Metabolite)\n", "df_cas[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cas[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df_cas[\"CE\"] = 20.0 # actually stepped 20/35/50\n", 
- "df_cas[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", + "df_cas[\"Metabolite\"].apply(\n", + " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", + ")\n", + "df_cas[\"CE\"] = 20.0 # actually stepped 20/35/50\n", + "df_cas[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", "\n", - "metadata_key_map16 = {\"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"PRECURSOR_MZ\",\n", - " 'precursor_mode': \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\"\n", - " }\n", + "metadata_key_map16 = {\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"precursor_mz\": \"PRECURSOR_MZ\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"RETENTIONTIME\",\n", + "}\n", "\n", - "df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n", - "df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder), axis=1)\n", + "df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", + ")\n", + "df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder), axis=1\n", + ")\n", "\n", "# Fragmentation\n", "df_cas[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "df_cas.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM), axis=1) # Optional: use mz_cut instead\n", + "df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " ),\n", + " axis=1,\n", + ") # Optional: use mz_cut instead\n", "\n", "#\n", "# CASMI 22\n", @@ -342,22 +379,37 @@ "df_cas22[\"Metabolite\"] = df_cas22[\"SMILES\"].apply(Metabolite)\n", "df_cas22[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cas22[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df_cas22[\"CE\"] = df_cas22.apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"precursor_mz\"]), axis=1)\n", + "df_cas22[\"Metabolite\"].apply(\n", + " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", + ")\n", + "df_cas22[\"CE\"] = df_cas22.apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"precursor_mz\"]), axis=1\n", + ")\n", "\n", - "metadata_key_map22 = {\"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"precursor_mz\",\n", - " 'precursor_mode': \"Precursor_type\",\n", - " \"retention_time\": \"ChallengeRT\"\n", - " }\n", + "metadata_key_map22 = {\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"precursor_mz\": \"precursor_mz\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"ChallengeRT\",\n", + "}\n", "\n", - "df_cas22[\"summary\"] = df_cas22.apply(lambda x: {key: x[name] for key, name in metadata_key_map22.items()}, axis=1)\n", - "df_cas22.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + "df_cas22[\"summary\"] = df_cas22.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map22.items()}, axis=1\n", + ")\n", + "df_cas22.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + ")\n", "\n", "# 
Fragmentation\n", "df_cas22[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "df_cas22.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM), axis=1) # Optional: use mz_cut instead\n", + "df_cas22.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " ),\n", + " axis=1,\n", + ") # Optional: use mz_cut instead\n", "\n", "df_cas22 = df_cas22.reset_index()" ] @@ -387,15 +439,13 @@ "from fiora.GNN.Trainer import Trainer\n", "import torch_geometric as geom\n", "\n", - "if torch.cuda.is_available(): \n", - " torch.cuda.empty_cache()\n", - " dev = \"cuda:1\"\n", - "else: \n", - " dev = \"cpu\" \n", - " \n", - "print(f\"Running on device: {dev}\")\n", + "if torch.cuda.is_available():\n", + " torch.cuda.empty_cache()\n", + " dev = \"cuda:1\"\n", + "else:\n", + " dev = \"cpu\"\n", "\n", - "\n" + "print(f\"Running on device: {dev}\")" ] }, { @@ -463,28 +513,28 @@ "outputs": [], "source": [ "model_params = {\n", - " 'param_tag': 'default',\n", - " 'gnn_type': 'RGCNConv',\n", - " 'depth': 5,\n", - " 'hidden_dimension': 300,\n", - " 'dense_layers': 2,\n", - " 'embedding_aggregation': 'concat',\n", - " 'embedding_dimension': 300,\n", - " 'input_dropout': 0.2,\n", - " 'latent_dropout': 0.1,\n", - " 'node_feature_layout': node_encoder.feature_numbers,\n", - " 'edge_feature_layout': bond_encoder.feature_numbers, \n", - " 'static_feature_dimension': geo_data[0][\"static_edge_features\"].shape[1],\n", - " 'static_rt_feature_dimension': geo_data[0][\"static_rt_features\"].shape[1],\n", - " 'output_dimension': len(DEFAULT_MODES) * 2, # per edge \n", + " \"param_tag\": \"default\",\n", + " \"gnn_type\": \"RGCNConv\",\n", + " \"depth\": 5,\n", + " \"hidden_dimension\": 300,\n", + " \"dense_layers\": 2,\n", + " \"embedding_aggregation\": \"concat\",\n", + " \"embedding_dimension\": 300,\n", + " \"input_dropout\": 0.2,\n", + " \"latent_dropout\": 0.1,\n", + " \"node_feature_layout\": node_encoder.feature_numbers,\n", + " \"edge_feature_layout\": bond_encoder.feature_numbers,\n", + " \"static_feature_dimension\": geo_data[0][\"static_edge_features\"].shape[1],\n", + " \"static_rt_feature_dimension\": geo_data[0][\"static_rt_features\"].shape[1],\n", + " \"output_dimension\": len(DEFAULT_MODES) * 2, # per edge\n", "}\n", "training_params = {\n", - " 'epochs': 200 if not test_run else 10, \n", - " 'batch_size': 256, #128,\n", + " \"epochs\": 200 if not test_run else 10,\n", + " \"batch_size\": 256, # 128,\n", " #'train_val_split': 0.90,\n", - " 'learning_rate': 0.0004,#0.001,\n", - " 'with_RT': False,\n", - " 'with_CCS': False\n", + " \"learning_rate\": 0.0004, # 0.001,\n", + " \"with_RT\": False,\n", + " \"with_CCS\": False,\n", "}" ] }, @@ -494,13 +544,22 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "fixed_params = {\"gnn_type\": \"RGCNConv\"} # Mainly used for clarity\n", - "grid_params = [{'depth': 0}, {'depth': 1}, {'depth': 2}, {'depth': 3}, {'depth': 4}, {'depth': 5}, {'depth': 6}, {'depth': 7}, {'depth': 8}]#,{'depth': 0}, {'depth': 1}, {'depth': 7}, {'depth': 8}]\n", + "fixed_params = {\"gnn_type\": \"RGCNConv\"} # Mainly used for clarity\n", + "grid_params = [\n", + " {\"depth\": 0},\n", + " {\"depth\": 1},\n", + " {\"depth\": 2},\n", + " {\"depth\": 3},\n", + " {\"depth\": 4},\n", + " {\"depth\": 5},\n", + " {\"depth\": 6},\n", + " {\"depth\": 7},\n", + " {\"depth\": 8},\n", + "] # ,{'depth': 0}, {'depth': 
1}, {'depth': 7}, {'depth': 8}]\n", "for p in grid_params:\n", " p.update(fixed_params)\n", - "#grid_params = [{'gnn_type': \"GraphConv\"}, {'gnn_type': \"RGCNConv\"}, {'gnn_type': \"GAT\"}, {'gnn_type': \"TransformerConv\"}]\n", - "#grid_params = [{'embedding_dimension': 300}, {'embedding_dimension': 400}, {'embedding_dimension': 500}]" + "# grid_params = [{'gnn_type': \"GraphConv\"}, {'gnn_type': \"RGCNConv\"}, {'gnn_type': \"GAT\"}, {'gnn_type': \"TransformerConv\"}]\n", + "# grid_params = [{'embedding_dimension': 300}, {'embedding_dimension': 400}, {'embedding_dimension': 500}]" ] }, { @@ -521,46 +580,78 @@ "from fiora.GNN.Losses import WeightedMSELoss, WeightedMSEMetric\n", "from fiora.MS.SimulationFramework import SimulationFramework\n", "\n", - "fiora = SimulationFramework(None, dev=dev, with_RT=training_params[\"with_RT\"], with_CCS=training_params[\"with_CCS\"])\n", - "np.seterr(invalid='ignore')\n", + "fiora = SimulationFramework(\n", + " None,\n", + " dev=dev,\n", + " with_RT=training_params[\"with_RT\"],\n", + " with_CCS=training_params[\"with_CCS\"],\n", + ")\n", + "np.seterr(invalid=\"ignore\")\n", "val_interval = 1\n", "tag = \"grid\"\n", "val_interval = 1\n", - "metric_dict= {\"mse\": WeightedMSEMetric}\n", + "metric_dict = {\"mse\": WeightedMSEMetric}\n", "loss_fn = WeightedMSELoss()\n", "\n", + "\n", "def train_new_model():\n", " model = GNNCompiler(model_params).to(dev)\n", - " \n", - " y_label = 'compiled_probsALL'\n", - " train_keys, val_keys = df_train[df_train[\"dataset\"] == \"training\"][\"group_id\"].unique(), df_train[df_train[\"dataset\"] == \"validation\"][\"group_id\"].unique()\n", - " trainer = Trainer(geo_data, y_tag=y_label, problem_type=\"regression\", train_keys=train_keys, val_keys=val_keys, metric_dict=metric_dict, split_by_group=True, seed=seed, device=dev)\n", - " optimizer = torch.optim.Adam(model.parameters(), lr=training_params[\"learning_rate\"])\n", - " #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98) \n", - " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( # TODO doesn't work with onlyTraining\n", - " optimizer,\n", - " patience = 8, # 10 default\n", - " #factor = self.hparams['factor'],\n", - " mode = 'min',\n", - " verbose = True)\n", - " \n", - " \n", + "\n", + " y_label = \"compiled_probsALL\"\n", + " train_keys, val_keys = (\n", + " df_train[df_train[\"dataset\"] == \"training\"][\"group_id\"].unique(),\n", + " df_train[df_train[\"dataset\"] == \"validation\"][\"group_id\"].unique(),\n", + " )\n", + " trainer = Trainer(\n", + " geo_data,\n", + " y_tag=y_label,\n", + " problem_type=\"regression\",\n", + " train_keys=train_keys,\n", + " val_keys=val_keys,\n", + " metric_dict=metric_dict,\n", + " split_by_group=True,\n", + " seed=seed,\n", + " device=dev,\n", + " )\n", + " optimizer = torch.optim.Adam(\n", + " model.parameters(), lr=training_params[\"learning_rate\"]\n", + " )\n", + " # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)\n", + " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( # TODO doesn't work with onlyTraining\n", + " optimizer,\n", + " patience=8, # 10 default\n", + " # factor = self.hparams['factor'],\n", + " mode=\"min\",\n", + " verbose=True,\n", + " )\n", + "\n", " # from accelerate import notebook_launcher\n", " # args = {\"model\": model, \"optimizer\": optimizer, \"loss_fn\": loss_fn, \"scheduler\": scheduler, \"batch_size\": training_params['batch_size'], \"epochs\": training_params[\"epochs\"], \"val_every_n_epochs\": val_interval, \"with_RT\": True, 
\"masked_validation\": training_params[\"with_RT\"], \"mask_name\": \"compiled_validation_maskALL\"}\n", " # notebook_launcher(trainer.train, args, num_processes=4)\n", - " checkpoints = trainer.train(model, optimizer, loss_fn, scheduler=scheduler, batch_size=training_params['batch_size'], epochs=training_params[\"epochs\"], val_every_n_epochs=val_interval, with_CCS=training_params[\"with_CCS\"], with_RT=training_params[\"with_RT\"], masked_validation=False, tag=tag) #, mask_name=\"compiled_validation_maskALL\")\n", + " checkpoints = trainer.train(\n", + " model,\n", + " optimizer,\n", + " loss_fn,\n", + " scheduler=scheduler,\n", + " batch_size=training_params[\"batch_size\"],\n", + " epochs=training_params[\"epochs\"],\n", + " val_every_n_epochs=val_interval,\n", + " with_CCS=training_params[\"with_CCS\"],\n", + " with_RT=training_params[\"with_RT\"],\n", + " masked_validation=False,\n", + " tag=tag,\n", + " ) # , mask_name=\"compiled_validation_maskALL\")\n", " print(checkpoints)\n", " return model, checkpoints\n", "\n", "\n", - "\n", "def simulate_all(model, DF):\n", " return fiora.simulate_all(DF, model)\n", "\n", - " \n", + "\n", "def test_model(model, DF):\n", " dft = simulate_all(model, DF)\n", - " \n", + "\n", " return dft[\"spectral_sqrt_cosine\"].values" ] }, @@ -578,48 +669,95 @@ "outputs": [], "source": [ "from fiora.MOL.collision_energy import NCE_to_eV\n", - "from fiora.MS.spectral_scores import spectral_cosine, spectral_reflection_cosine, reweighted_dot\n", + "from fiora.MS.spectral_scores import (\n", + " spectral_cosine,\n", + " spectral_reflection_cosine,\n", + " reweighted_dot,\n", + ")\n", "from fiora.MS.ms_utility import merge_annotated_spectrum\n", "\n", + "\n", "def test_cas16(model, df_cas=df_cas):\n", - " \n", - " df_cas[\"NCE\"] = 20.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1)\n", + "\n", + " df_cas[\"NCE\"] = 20.0 # actually stepped NCE 20/35/50\n", + " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " )\n", " df_cas[\"step1_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n", - " df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + " df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", + " )\n", + " df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + " )\n", " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_20\")\n", "\n", - " df_cas[\"NCE\"] = 35.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1)\n", + " df_cas[\"NCE\"] = 35.0 # actually stepped NCE 20/35/50\n", + " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " )\n", " df_cas[\"step2_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n", - " df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + " df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: 
x[name] for key, name in metadata_key_map16.items()}, axis=1\n", + " )\n", + " df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + " )\n", " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_35\")\n", "\n", - "\n", - " df_cas[\"NCE\"] = 50.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1)\n", + " df_cas[\"NCE\"] = 50.0 # actually stepped NCE 20/35/50\n", + " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " )\n", " df_cas[\"step3_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n", - " df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + " df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", + " )\n", + " df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + " )\n", " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_50\")\n", "\n", - " df_cas[\"avg_CE\"] = (df_cas[\"step1_CE\"] + df_cas[\"step2_CE\"] + df_cas[\"step3_CE\"]) / 3\n", + " df_cas[\"avg_CE\"] = (\n", + " df_cas[\"step1_CE\"] + df_cas[\"step2_CE\"] + df_cas[\"step3_CE\"]\n", + " ) / 3\n", "\n", - " df_cas[\"merged_peaks\"] = df_cas.apply(lambda x: merge_annotated_spectrum(merge_annotated_spectrum(x[\"sim_peaks_20\"], x[\"sim_peaks_35\"]), x[\"sim_peaks_50\"]) , axis=1)\n", - " df_cas[\"merged_cosine\"] = df_cas.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"]), axis=1)\n", - " df_cas[\"merged_sqrt_cosine\"] = df_cas.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt), axis=1)\n", - " df_cas[\"merged_refl_cosine\"] = df_cas.apply(lambda x: spectral_reflection_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt), axis=1)\n", - " df_cas[\"merged_steins\"] = df_cas.apply(lambda x: reweighted_dot(x[\"peaks\"], x[\"merged_peaks\"]), axis=1)\n", - " df_cas[\"spectral_sqrt_cosine\"] = df_cas[\"merged_sqrt_cosine\"] # just remember it is merged\n", + " df_cas[\"merged_peaks\"] = df_cas.apply(\n", + " lambda x: merge_annotated_spectrum(\n", + " merge_annotated_spectrum(x[\"sim_peaks_20\"], x[\"sim_peaks_35\"]),\n", + " x[\"sim_peaks_50\"],\n", + " ),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_cosine\"] = df_cas.apply(\n", + " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " )\n", + " df_cas[\"merged_sqrt_cosine\"] = df_cas.apply(\n", + " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_refl_cosine\"] = df_cas.apply(\n", + " lambda x: spectral_reflection_cosine(\n", + " x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt\n", + " ),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_steins\"] = df_cas.apply(\n", + " lambda x: reweighted_dot(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " )\n", + " df_cas[\"spectral_sqrt_cosine\"] = df_cas[\n", + " \"merged_sqrt_cosine\"\n", + " ] # just remember it is merged\n", "\n", " df_cas[\"coverage\"] = df_cas[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", " df_cas[\"RT_pred\"] = df_cas[\"RT_pred_35\"]\n", " df_cas[\"RT_dif\"] = 
df_cas[\"RT_dif_35\"]\n", " df_cas[\"CCS_pred\"] = df_cas[\"CCS_pred_35\"]\n", " df_cas[\"library\"] = \"CASMI-16\"\n", - " \n", + "\n", " return df_cas[\"merged_sqrt_cosine\"].values" ] }, @@ -2720,29 +2858,71 @@ " print(f\"Testing {params}\")\n", " model_params.update(params)\n", " current_model, checkpoint = train_new_model()\n", - " val_results = test_model(current_model, df_train[df_train[\"dataset\"]== \"validation\"])\n", + " val_results = test_model(\n", + " current_model, df_train[df_train[\"dataset\"] == \"validation\"]\n", + " )\n", " test_results = test_model(current_model, df_test)\n", " casmi16_results = test_cas16(current_model)\n", " casmi16_p = test_cas16(current_model, df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"])\n", " casmi16_n = test_cas16(current_model, df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"])\n", " casmi22_results = test_model(current_model, df_cas22)\n", - " casmi22_p = test_model(current_model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"])\n", - " casmi22_n = test_model(current_model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"])\n", - " \n", - " results.append({**params, \"model\": copy.deepcopy(current_model), \"cp\": checkpoint, \"validation\": val_results, \"test\": test_results, \"casmi16\": casmi16_results, \"casmi22\": casmi22_results, \"casmi16+\": casmi16_p, \"casmi16-\": casmi16_n, \"casmi22+\": casmi22_p, \"casmi22-\": casmi22_n})\n", - " \n", + " casmi22_p = test_model(\n", + " current_model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"]\n", + " )\n", + " casmi22_n = test_model(\n", + " current_model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"]\n", + " )\n", + "\n", + " results.append(\n", + " {\n", + " **params,\n", + " \"model\": copy.deepcopy(current_model),\n", + " \"cp\": checkpoint,\n", + " \"validation\": val_results,\n", + " \"test\": test_results,\n", + " \"casmi16\": casmi16_results,\n", + " \"casmi22\": casmi22_results,\n", + " \"casmi16+\": casmi16_p,\n", + " \"casmi16-\": casmi16_n,\n", + " \"casmi22+\": casmi22_p,\n", + " \"casmi22-\": casmi22_n,\n", + " }\n", + " )\n", + "\n", " current_model = current_model.load(checkpoint[\"file\"])\n", - " val_results_cp = test_model(current_model, df_train[df_train[\"dataset\"]== \"validation\"])\n", + " val_results_cp = test_model(\n", + " current_model, df_train[df_train[\"dataset\"] == \"validation\"]\n", + " )\n", " test_results_cp = test_model(current_model, df_test)\n", " casmi16_results_cp = test_cas16(current_model)\n", - " casmi16_p_cp = test_cas16(current_model, df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"])\n", - " casmi16_n_cp = test_cas16(current_model, df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"])\n", + " casmi16_p_cp = test_cas16(\n", + " current_model, df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"]\n", + " )\n", + " casmi16_n_cp = test_cas16(\n", + " current_model, df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"]\n", + " )\n", " casmi22_results_cp = test_model(current_model, df_cas22)\n", - " casmi22_p_cp = test_model(current_model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"])\n", - " casmi22_n_cp = test_model(current_model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"])\n", - " results_cp.append({**params, \"model\": copy.deepcopy(current_model), \"cp\": checkpoint, \"validation\": val_results_cp, \"test\": test_results_cp, \"casmi16\": casmi16_results_cp, \"casmi22\": casmi22_results_cp, \"casmi16+\": casmi16_p_cp, \"casmi16-\": casmi16_n_cp, \"casmi22+\": casmi22_p_cp, \"casmi22-\": casmi22_n_cp})\n", - " \n", - " 
" + " casmi22_p_cp = test_model(\n", + " current_model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"]\n", + " )\n", + " casmi22_n_cp = test_model(\n", + " current_model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"]\n", + " )\n", + " results_cp.append(\n", + " {\n", + " **params,\n", + " \"model\": copy.deepcopy(current_model),\n", + " \"cp\": checkpoint,\n", + " \"validation\": val_results_cp,\n", + " \"test\": test_results_cp,\n", + " \"casmi16\": casmi16_results_cp,\n", + " \"casmi22\": casmi22_results_cp,\n", + " \"casmi16+\": casmi16_p_cp,\n", + " \"casmi16-\": casmi16_n_cp,\n", + " \"casmi22+\": casmi22_p_cp,\n", + " \"casmi22-\": casmi22_n_cp,\n", + " }\n", + " )" ] }, { @@ -2762,7 +2942,7 @@ } ], "source": [ - "np.median(val_results)\n" + "np.median(val_results)" ] }, { @@ -3177,8 +3357,8 @@ "source": [ "LOGIC = pd.read_csv(home_path + NAME, sep=\"\\t\")\n", "for col in eval_columns:\n", - " LOGIC[col] = LOGIC[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", - "#LOGIC[eval_columns].apply(lambda x: x.apply(np.median))" + " LOGIC[col] = LOGIC[col].apply(lambda x: ast.literal_eval(x.replace(\"nan\", \"None\")))\n", + "# LOGIC[eval_columns].apply(lambda x: x.apply(np.median))" ] }, { @@ -3403,7 +3583,7 @@ ], "source": [ "LOGIC[eval_columns] = LOGIC[eval_columns].apply(lambda x: x.apply(np.median))\n", - "LOGIC\n" + "LOGIC" ] } ], diff --git a/notebooks/grid_stats.ipynb b/notebooks/grid_stats.ipynb index 5066290..5fc3b48 100644 --- a/notebooks/grid_stats.ipynb +++ b/notebooks/grid_stats.ipynb @@ -40,12 +40,14 @@ "# Load Modules\n", "sys.path.append(\"..\")\n", "from os.path import expanduser\n", + "\n", "home = expanduser(\"~\")\n", "from fiora.MOL.constants import DEFAULT_PPM, PPM, DEFAULT_MODES\n", "from fiora.IO.LibraryLoader import LibraryLoader\n", - "from fiora.MOL.FragmentationTree import FragmentationTree \n", + "from fiora.MOL.FragmentationTree import FragmentationTree\n", "from fiora.visualization.define_colors import set_light_theme\n", "from fiora.visualization.define_colors import *\n", + "\n", "set_light_theme()\n", "import fiora.visualization.spectrum_visualizer as sv\n", "\n", @@ -56,9 +58,10 @@ "from sklearn.metrics import r2_score\n", "import scipy\n", "from rdkit import RDLogger\n", - "RDLogger.DisableLog('rdApp.*')\n", "\n", - "print(f'Working with Python {sys.version}')\n" + "RDLogger.DisableLog(\"rdApp.*\")\n", + "\n", + "print(f\"Working with Python {sys.version}\")" ] }, { @@ -80,7 +83,7 @@ " eval_columns = LOG.columns[3:]\n", "\n", " for col in eval_columns:\n", - " LOG[col] = LOG[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", + " LOG[col] = LOG[col].apply(lambda x: ast.literal_eval(x.replace(\"nan\", \"None\")))\n", " return LOG" ] }, @@ -92,7 +95,7 @@ "source": [ "path = f\"{home}/data/metabolites/benchmarking/\"\n", "suffix = \"_cp_Jan24.csv\"\n", - "NAMES = [\"GraphConv\", \"RGCNConv\", \"GAT\", \"TransformerConv\"] # , \"CGConv_depth.csv\",\n", + "NAMES = [\"GraphConv\", \"RGCNConv\", \"GAT\", \"TransformerConv\"] # , \"CGConv_depth.csv\",\n", "\n", "log = []\n", "for name in NAMES:\n", @@ -1075,7 +1078,7 @@ } ], "source": [ - "log\n" + "log" ] }, { @@ -1084,7 +1087,13 @@ "metadata": {}, "outputs": [], "source": [ - "gnn_type_labels={\"GraphConv\": \"GCN\", \"CGConv\": \"CGC\", \"GAT\": \"GAT\", \"RGCNConv\": \"RGCN\", \"TransformerConv\": \"Transformer\"}" + "gnn_type_labels = {\n", + " \"GraphConv\": \"GCN\",\n", + " \"CGConv\": \"CGC\",\n", + " \"GAT\": \"GAT\",\n", + " \"RGCNConv\": \"RGCN\",\n", + 
" \"TransformerConv\": \"Transformer\",\n", + "}" ] }, { @@ -1125,30 +1134,56 @@ } ], "source": [ - "fig, ax = plt.subplots(1,1, figsize=(12, 6))\n", + "fig, ax = plt.subplots(1, 1, figsize=(12, 6))\n", "plt.subplots_adjust(top=0.94, bottom=0.12, right=0.97, left=0.08)\n", "\n", - "color_blind4=[sns.color_palette(\"colorblind\")[4], sns.color_palette(\"colorblind\")[0], sns.color_palette(\"colorblind\")[2], sns.color_palette(\"colorblind\")[3]]\n", + "color_blind4 = [\n", + " sns.color_palette(\"colorblind\")[4],\n", + " sns.color_palette(\"colorblind\")[0],\n", + " sns.color_palette(\"colorblind\")[2],\n", + " sns.color_palette(\"colorblind\")[3],\n", + "]\n", "L = log.explode(\"validation\")\n", "L[\"gnn_type\"] = L[\"gnn_type\"].map(gnn_type_labels)\n", - "sns.pointplot(data=L, x=\"depth\", y=\"validation\", estimator=\"median\", capsize=.0, markers=\"o\", palette=color_blind4, markersize=5, errorbar=('ci', 95), linestyles='--', hue=\"gnn_type\", dodge=0.4) # ci=('ci', 0.95),, palette=tri_palette, bins=[x/10.0 for x in list(range(0,10,1))]), multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", - "#plt.ylim([0.4, 0.8]) # markers=[\"o\", \"X\", \"^\",\"*\"]\n", + "sns.pointplot(\n", + " data=L,\n", + " x=\"depth\",\n", + " y=\"validation\",\n", + " estimator=\"median\",\n", + " capsize=0.0,\n", + " markers=\"o\",\n", + " palette=color_blind4,\n", + " markersize=5,\n", + " errorbar=(\"ci\", 95),\n", + " linestyles=\"--\",\n", + " hue=\"gnn_type\",\n", + " dodge=0.4,\n", + ") # ci=('ci', 0.95),, palette=tri_palette, bins=[x/10.0 for x in list(range(0,10,1))]), multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", + "# plt.ylim([0.4, 0.8]) # markers=[\"o\", \"X\", \"^\",\"*\"]\n", "plt.ylabel(\"Cosine similarity\")\n", "plt.xlabel(\"Graph network depth\")\n", "plt.yticks(np.arange(0.5, 0.91, 0.05))\n", "# plt.autoscale(enable=True, axis='y')\n", "plt.ylim(0.575, 0.87)\n", - "#plt.ylim(0.25, 0.95)\n", + "# plt.ylim(0.25, 0.95)\n", "plt.legend(title=\"\")\n", "\n", "# for line in ax.lines:\n", "# marker = line.get_marker()\n", - "# line.set_markeredgecolor('black') \n", + "# line.set_markeredgecolor('black')\n", "\n", - "plt.rc('axes', labelsize=14)\n", - "plt.rc('legend', fontsize=14)\n", - "ax.tick_params(axis='both', which='major', labelsize=13)\n", - "ax.text(0.02, 0.02, \"n=7212 for all data points\", transform=ax.transAxes, fontsize=13, va='bottom', ha='left')\n", + "plt.rc(\"axes\", labelsize=14)\n", + "plt.rc(\"legend\", fontsize=14)\n", + "ax.tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", + "ax.text(\n", + " 0.02,\n", + " 0.02,\n", + " \"n=7212 for all data points\",\n", + " transform=ax.transAxes,\n", + " fontsize=13,\n", + " va=\"bottom\",\n", + " ha=\"left\",\n", + ")\n", "\n", "# fig.savefig(f\"{home}/images/paper/grid_params3.svg\", format=\"svg\", dpi=600)\n", "# # fig.savefig(f\"{home}/images/paper/grid_params3.png\", format=\"png\", dpi=600)\n", @@ -1173,29 +1208,55 @@ } ], "source": [ - "fig, ax = plt.subplots(1,1, figsize=(12, 6))\n", + "fig, ax = plt.subplots(1, 1, figsize=(12, 6))\n", "plt.subplots_adjust(top=0.94, bottom=0.12, right=0.97, left=0.08)\n", "\n", - "color_blind4=[sns.color_palette(\"colorblind\")[4], sns.color_palette(\"colorblind\")[0], sns.color_palette(\"colorblind\")[2], sns.color_palette(\"colorblind\")[3]]\n", + "color_blind4 = [\n", + " sns.color_palette(\"colorblind\")[4],\n", + " 
sns.color_palette(\"colorblind\")[0],\n", + " sns.color_palette(\"colorblind\")[2],\n", + " sns.color_palette(\"colorblind\")[3],\n", + "]\n", "L = log.explode(\"validation\")\n", "L[\"gnn_type\"] = L[\"gnn_type\"].map(gnn_type_labels)\n", - "sns.pointplot(data=L, x=\"depth\", y=\"validation\", estimator=\"median\", capsize=.0, markers=\"o\", palette=color_blind4, markersize=5, errorbar=('pi', 50), linestyles='--', hue=\"gnn_type\", dodge=0.4) # ci=('ci', 0.95),, palette=tri_palette, bins=[x/10.0 for x in list(range(0,10,1))]), multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", - "#plt.ylim([0.4, 0.8]) # markers=[\"o\", \"X\", \"^\",\"*\"]\n", + "sns.pointplot(\n", + " data=L,\n", + " x=\"depth\",\n", + " y=\"validation\",\n", + " estimator=\"median\",\n", + " capsize=0.0,\n", + " markers=\"o\",\n", + " palette=color_blind4,\n", + " markersize=5,\n", + " errorbar=(\"pi\", 50),\n", + " linestyles=\"--\",\n", + " hue=\"gnn_type\",\n", + " dodge=0.4,\n", + ") # ci=('ci', 0.95),, palette=tri_palette, bins=[x/10.0 for x in list(range(0,10,1))]), multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", + "# plt.ylim([0.4, 0.8]) # markers=[\"o\", \"X\", \"^\",\"*\"]\n", "plt.ylabel(\"Cosine similarity\")\n", "plt.xlabel(\"Graph network depth\")\n", "plt.yticks(np.arange(0.4, 0.91, 0.05))\n", "# plt.autoscale(enable=True, axis='y')\n", "plt.ylim(0.4, 0.95)\n", - "#plt.ylim(0.25, 0.95)\n", + "# plt.ylim(0.25, 0.95)\n", "plt.legend(title=\"\")\n", "\n", "# for line in ax.lines:\n", "# marker = line.get_marker()\n", - "# line.set_markeredgecolor('black') \n", - "plt.rc('axes', labelsize=14)\n", - "plt.rc('legend', fontsize=14)\n", - "ax.tick_params(axis='both', which='major', labelsize=13)\n", - "ax.text(0.02, 0.02, \"n=7212 for all data points\", transform=ax.transAxes, fontsize=13, va='bottom', ha='left')\n", + "# line.set_markeredgecolor('black')\n", + "plt.rc(\"axes\", labelsize=14)\n", + "plt.rc(\"legend\", fontsize=14)\n", + "ax.tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", + "ax.text(\n", + " 0.02,\n", + " 0.02,\n", + " \"n=7212 for all data points\",\n", + " transform=ax.transAxes,\n", + " fontsize=13,\n", + " va=\"bottom\",\n", + " ha=\"left\",\n", + ")\n", "\n", "# fig.savefig(f\"{home}/images/paper/grid_params3_iqr.svg\", format=\"svg\", dpi=600)\n", "# # fig.savefig(f\"{home}/images/paper/grid_params3_iqr.png\", format=\"png\", dpi=600)\n", @@ -1209,15 +1270,18 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np \n", - "import scipy.stats as st \n", + "import numpy as np\n", + "import scipy.stats as st\n", "from scipy.stats import bootstrap\n", "\n", + "\n", "def print_stats(scores):\n", " print(f\"Num of spectra: {len(scores)}\")\n", " print(f\"Median:\\t{np.median(scores):.3f}\")\n", " print(f\"Var:\\t{np.var(scores):.3f} (Standard deviation: {np.std(scores):.3f})\")\n", - " conf_in_t = st.t.interval(confidence=0.95, df=len(scores)-1, loc=np.median(scores), scale=st.sem(scores)) \n", + " conf_in_t = st.t.interval(\n", + " confidence=0.95, df=len(scores) - 1, loc=np.median(scores), scale=st.sem(scores)\n", + " )\n", " conf_in_boot = bootstrap((scores,), np.median, confidence_level=0.95)\n", " print(f\"95%CI: {conf_in_t} (from t distribution)\")\n", " print(f\"95%CI: {conf_in_boot.confidence_interval} (from bootstrapping)\")" @@ -1416,7 +1480,9 @@ ], "source": [ "eval_columns = log.columns[4:]\n", - "log[log[\"gnn_type\"] == 
\"RGCNConv\"][eval_columns].apply(lambda x: x.apply(np.mean) ,axis=1)" + "log[log[\"gnn_type\"] == \"RGCNConv\"][eval_columns].apply(\n", + " lambda x: x.apply(np.mean), axis=1\n", + ")" ] }, { @@ -1480,7 +1546,7 @@ } ], "source": [ - "sns.boxplot(data=L4, y=\"validation\")\n" + "sns.boxplot(data=L4, y=\"validation\")" ] }, { @@ -2474,7 +2540,13 @@ "metadata": {}, "outputs": [], "source": [ - "L_export = L_export.rename(columns={\"depth\": \"Depth\", \"gnn_type\": \"GNN Architecture\", \"validation\": \"Cosine Similarity\"})" + "L_export = L_export.rename(\n", + " columns={\n", + " \"depth\": \"Depth\",\n", + " \"gnn_type\": \"GNN Architecture\",\n", + " \"validation\": \"Cosine Similarity\",\n", + " }\n", + ")" ] }, { @@ -2483,7 +2555,9 @@ "metadata": {}, "outputs": [], "source": [ - "L_export['Validation ID'] = L_export.groupby(['Depth', 'GNN Architecture']).cumcount() + 1" + "L_export[\"Validation ID\"] = (\n", + " L_export.groupby([\"Depth\", \"GNN Architecture\"]).cumcount() + 1\n", + ")" ] }, { @@ -2501,7 +2575,9 @@ "metadata": {}, "outputs": [], "source": [ - "L_export[[\"Depth\", \"GNN Architecture\", \"Validation ID\", \"Cosine Similarity\"]].to_excel(f\"{home}/images/paper/SourceData_Figure2.xlsx\")" + "L_export[[\"Depth\", \"GNN Architecture\", \"Validation ID\", \"Cosine Similarity\"]].to_excel(\n", + " f\"{home}/images/paper/SourceData_Figure2.xlsx\"\n", + ")" ] } ], diff --git a/notebooks/info_graphs.ipynb b/notebooks/info_graphs.ipynb index 65f145d..9e24c71 100644 --- a/notebooks/info_graphs.ipynb +++ b/notebooks/info_graphs.ipynb @@ -27,7 +27,7 @@ "import torch\n", "\n", "seed = 42\n", - "#torch.set_default_dtype(torch.float64)\n", + "# torch.set_default_dtype(torch.float64)\n", "torch.manual_seed(seed)\n", "torch.set_printoptions(precision=2, sci_mode=False)\n", "\n", @@ -40,18 +40,20 @@ "# Load Modules\n", "sys.path.append(\"..\")\n", "from os.path import expanduser\n", + "\n", "home = expanduser(\"~\")\n", "from fiora.MOL.constants import DEFAULT_PPM, PPM, DEFAULT_MODES\n", "from fiora.IO.LibraryLoader import LibraryLoader\n", - "from fiora.MOL.FragmentationTree import FragmentationTree \n", + "from fiora.MOL.FragmentationTree import FragmentationTree\n", "import fiora.visualization.spectrum_visualizer as sv\n", "\n", "from sklearn.metrics import r2_score\n", "import scipy\n", "from rdkit import RDLogger\n", - "RDLogger.DisableLog('rdApp.*')\n", "\n", - "print(f'Working with Python {sys.version}')\n" + "RDLogger.DisableLog(\"rdApp.*\")\n", + "\n", + "print(f\"Working with Python {sys.version}\")" ] }, { @@ -69,12 +71,15 @@ ], "source": [ "from typing import Literal\n", - "lib: Literal[\"NIST\", \"MSDIAL\", \"NIST/MSDIAL\", \"MSnLib\"] = \"MSnLib\" #\"MSnLib\"\n", + "\n", + "lib: Literal[\"NIST\", \"MSDIAL\", \"NIST/MSDIAL\", \"MSnLib\"] = \"MSnLib\" # \"MSnLib\"\n", "print(f\"Preparing {lib} library\")\n", "\n", - "debug_mode = False # Default: False\n", + "debug_mode = False # Default: False\n", "if debug_mode:\n", - " print(\"+++ This is a test run (debug mode) with a small subset of data points. Results are not representative. +++\")" + " print(\n", + " \"+++ This is a test run (debug mode) with a small subset of data points. Results are not representative. 
+++\"\n", + " )" ] }, { @@ -85,41 +90,45 @@ "source": [ "# key map to read metadata from pandas DataFrame\n", "metadata_key_map = {\n", - " \"name\": \"Name\",\n", - " \"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"ionization\": \"Ionization\",\n", - " \"precursor_mz\": \"PrecursorMZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", - " \"ccs\": \"CCS\"\n", - " }\n", + " \"name\": \"Name\",\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"ionization\": \"Ionization\",\n", + " \"precursor_mz\": \"PrecursorMZ\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"RETENTIONTIME\",\n", + " \"ccs\": \"CCS\",\n", + "}\n", "\n", "\n", "#\n", "# Load specified libraries and align metadata\n", "#\n", "\n", + "\n", "def load_training_data():\n", - " if (\"NIST\" in lib or \"MSDIAL\" in lib):\n", + " if \"NIST\" in lib or \"MSDIAL\" in lib:\n", " data_path: str = f\"{home}/data/metabolites/preprocessed/datasplits_Jan24.csv\"\n", " elif lib == \"MSnLib\":\n", - " data_path: str = f\"{home}/data/metabolites/preprocessed/datasplits_msnlib_April25_v1.csv\"\n", + " data_path: str = (\n", + " f\"{home}/data/metabolites/preprocessed/datasplits_msnlib_April25_v1.csv\"\n", + " )\n", " else:\n", " raise NameError(f\"Unknown library selected {lib=}.\")\n", " L = LibraryLoader()\n", " df = L.load_from_csv(data_path)\n", " return df\n", "\n", + "\n", "df = load_training_data()\n", "\n", "# Restore dictionary values\n", "dict_columns = [\"peaks\", \"summary\"]\n", "for col in dict_columns:\n", - " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", - " #df[col] = df[col].apply(ast.literal_eval)\n", - " \n", - "df['group_id'] = df['group_id'].astype(int)\n" + " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace(\"nan\", \"None\")))\n", + " # df[col] = df[col].apply(ast.literal_eval)\n", + "\n", + "df[\"group_id\"] = df[\"group_id\"].astype(int)" ] }, { @@ -139,15 +148,14 @@ "\n", "\n", "if debug_mode:\n", - " df = df.iloc[:10000,:]\n", - " #df = df.iloc[5000:20000,:]\n", + " df = df.iloc[:10000, :]\n", + " # df = df.iloc[5000:20000,:]\n", "\n", "overwrite_setup_features = None\n", "if lib == \"MSnLib\":\n", " overwrite_setup_features = {\n", " \"instrument\": [\"HCD\"],\n", - " \"precursor_mode\": [\"[M+H]+\", \"[M-H]-\", \"[M]+\", \"[M]-\"]\n", - " \n", + " \"precursor_mode\": [\"[M+H]+\", \"[M-H]-\", \"[M]+\", \"[M]-\"],\n", " }\n", "\n", "\n", @@ -156,18 +164,38 @@ "\n", "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", - "covariate_encoder = CovariateFeatureEncoder(feature_list=[\"collision_energy\", \"molecular_weight\", \"precursor_mode\", \"instrument\", \"element_composition\"], sets_overwrite=overwrite_setup_features)\n", - "rt_encoder = CovariateFeatureEncoder(feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\", \"element_composition\"], sets_overwrite=overwrite_setup_features)\n", + "covariate_encoder = CovariateFeatureEncoder(\n", + " feature_list=[\n", + " \"collision_energy\",\n", + " \"molecular_weight\",\n", + " \"precursor_mode\",\n", + " \"instrument\",\n", + " \"element_composition\",\n", + " ],\n", + " sets_overwrite=overwrite_setup_features,\n", + ")\n", + "rt_encoder = CovariateFeatureEncoder(\n", + " feature_list=[\n", + " 
\"molecular_weight\",\n", + " \"precursor_mode\",\n", + " \"instrument\",\n", + " \"element_composition\",\n", + " ],\n", + " sets_overwrite=overwrite_setup_features,\n", + ")\n", "\n", - "covariate_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit \n", - "covariate_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n", - "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n", + "covariate_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", + "covariate_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", + "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", "\n", "df[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]) , axis=1)\n", + "df.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]), axis=1)\n", "\n", - "#df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", - "df.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder), axis=1)\n", + "# df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", + "df.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder),\n", + " axis=1,\n", + ")\n", "_ = df.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)" ] }, @@ -198,7 +226,9 @@ ], "source": [ "mindex.create_fragmentation_trees()\n", - "list_of_mismatched_ids = mindex.add_fragmentation_trees_to_metabolite_list(df[\"Metabolite\"], graph_mismatch_policy=\"recompute\")\n", + "list_of_mismatched_ids = mindex.add_fragmentation_trees_to_metabolite_list(\n", + " df[\"Metabolite\"], graph_mismatch_policy=\"recompute\"\n", + ")\n", "print(f\"Total number of recomputed trees: {len(list_of_mismatched_ids)}\")" ] }, @@ -208,8 +238,13 @@ "metadata": {}, "outputs": [], "source": [ - "#df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "_ = df.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]), axis=1)" + "# df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "_ = df.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -257,7 +292,19 @@ "source": [ "from fiora.MOL.MetaboliteDatasetStatistics import MetaboliteDatasetStatistics\n", "\n", - "ORDERED_ELEMENT_LIST = [\"C\", \"H\", \"O\", \"N\", \"F\", \"Cl\", \"Br\", \"I\", \"P\", \"S\", \"Si\"] # same as in constants.py, but different order\n", + "ORDERED_ELEMENT_LIST = [\n", + " \"C\",\n", + " \"H\",\n", + " \"O\",\n", + " \"N\",\n", + " \"F\",\n", + " \"Cl\",\n", + " \"Br\",\n", + " \"I\",\n", + " \"P\",\n", + " \"S\",\n", + " \"Si\",\n", + "] # same as in constants.py, but different order\n", "stats = MetaboliteDatasetStatistics(df)\n", "stats.generate_molecular_statistics()" ] @@ -299,18 +346,27 @@ "source": [ "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", - "from fiora.visualization.define_colors import ELEMENT_COLORS, set_light_theme # Import the element colors\n", + "from fiora.visualization.define_colors import (\n", + " ELEMENT_COLORS,\n", + " set_light_theme,\n", + 
") # Import the element colors\n", "\n", - "set_light_theme() \n", + "set_light_theme()\n", "\n", "# Extract total counts for plotting\n", - "total_counts = stats.get_statistics()['Molecular Summary']['Total Counts']\n", + "total_counts = stats.get_statistics()[\"Molecular Summary\"][\"Total Counts\"]\n", "\n", "# Drop Hydrogen completely\n", - "filtered_counts = {element: count for element, count in total_counts.items() if element != \"H\"}\n", + "filtered_counts = {\n", + " element: count for element, count in total_counts.items() if element != \"H\"\n", + "}\n", "\n", "# Define rare elements (everything except C, O, N)\n", - "rare_elements = {element: count for element, count in filtered_counts.items() if element not in [\"C\", \"O\", \"N\"]}\n", + "rare_elements = {\n", + " element: count\n", + " for element, count in filtered_counts.items()\n", + " if element not in [\"C\", \"O\", \"N\"]\n", + "}\n", "\n", "# Create the main plot (all elements including rare ones)\n", "fig, ax_main = plt.subplots(figsize=(10, 6))\n", @@ -319,21 +375,21 @@ " y=list(filtered_counts.values()),\n", " ax=ax_main,\n", " palette=[ELEMENT_COLORS[element] for element in filtered_counts.keys()],\n", - " edgecolor='black', \n", + " edgecolor=\"black\",\n", ")\n", "ax_main.set_title(\"Element Composition\")\n", "ax_main.set_xlabel(\"Element\")\n", "ax_main.set_ylabel(\"Total Count\")\n", "\n", "# Create the zoomed-in plot for rare elements\n", - "ax_zoom_loc = [0.5, 0.4, 0.65, 0.45] # [x, y, width, height] in relative coordinates\n", - "ax_zoom = fig.add_axes(ax_zoom_loc) \n", + "ax_zoom_loc = [0.5, 0.4, 0.65, 0.45] # [x, y, width, height] in relative coordinates\n", + "ax_zoom = fig.add_axes(ax_zoom_loc)\n", "sns.barplot(\n", " x=list(rare_elements.keys()),\n", " y=list(rare_elements.values()),\n", " ax=ax_zoom,\n", " palette=[ELEMENT_COLORS[element] for element in rare_elements.keys()],\n", - " edgecolor='black',\n", + " edgecolor=\"black\",\n", ")\n", "ax_zoom.set_title(\"Rare Elements\")\n", "ax_zoom.set_xlabel(\"\")\n", @@ -354,7 +410,7 @@ "\n", "# # Add an arrow from F in the main plot to F in the subplot\n", "# ax_main.annotate(\n", - "# \"\", \n", + "# \"\",\n", "# xy=(subplot_x_pos, subplot_y_pos), # Arrow end point in the subplot\n", "# xytext=(f_index_main, f_count_main + 10), # Arrow start point in the main plot\n", "# arrowprops=dict(facecolor='black', edgecolor=\"black\", arrowstyle=\"->\"),\n", @@ -372,23 +428,33 @@ "# Calculate the relative position of Fluorine (F) in the subplot\n", "subplot_x_start, subplot_y_start, subplot_width, subplot_height = ax_zoom_loc\n", "subplot_x_pos = subplot_x_start + (f_index_zoom / len(rare_elements)) * subplot_width\n", - "subplot_y_pos = subplot_y_start \n", + "subplot_y_pos = subplot_y_start\n", "\n", "# Convert subplot coordinates to global figure coordinates\n", - "subplot_x_pos, subplot_y_pos = ax_zoom.transAxes.transform((f_index_zoom / len(rare_elements) + 0.038, -0.15)) # Relative position in figure coordinates\n", - "subplot_x_pos, subplot_y_pos = ax_main.transData.inverted().transform((subplot_x_pos, subplot_y_pos)) # Convert to main plot's data coordinates\n", + "subplot_x_pos, subplot_y_pos = ax_zoom.transAxes.transform(\n", + " (f_index_zoom / len(rare_elements) + 0.038, -0.15)\n", + ") # Relative position in figure coordinates\n", + "subplot_x_pos, subplot_y_pos = ax_main.transData.inverted().transform(\n", + " (subplot_x_pos, subplot_y_pos)\n", + ") # Convert to main plot's data coordinates\n", "\n", "# Add an arrow from F in the main plot to F in 
the subplot\n", "ax_main.annotate(\n", - " \"\", \n", - " xy=(subplot_x_pos, subplot_y_pos), # Arrow end point (Fluorine bar in subplot in global coordinates)\n", - " xytext=(f_index_main, f_count_main + 7000), # Arrow start point (Fluorine bar in main plot)\n", - " arrowprops=dict(facecolor='black', edgecolor=\"black\", arrowstyle=\"->\"),\n", + " \"\",\n", + " xy=(\n", + " subplot_x_pos,\n", + " subplot_y_pos,\n", + " ), # Arrow end point (Fluorine bar in subplot in global coordinates)\n", + " xytext=(\n", + " f_index_main,\n", + " f_count_main + 7000,\n", + " ), # Arrow start point (Fluorine bar in main plot)\n", + " arrowprops=dict(facecolor=\"black\", edgecolor=\"black\", arrowstyle=\"->\"),\n", ")\n", "\n", "# Adjust layout and show the plot\n", "plt.tight_layout()\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -412,8 +478,14 @@ "ELEMENT_COLORS[\"ANY_RARE\"] = \"#D3D3D3\" # Light gray for ANY_RARE\n", "\n", "# Extract presence probabilities for rare elements\n", - "presence_probabilities = stats.get_statistics()['Molecular Summary']['Presence Probabilities']\n", - "rare_element_probabilities = {element: prob for element, prob in presence_probabilities.items() if element not in [\"C\", \"O\", \"N\", \"H\"]}\n", + "presence_probabilities = stats.get_statistics()[\"Molecular Summary\"][\n", + " \"Presence Probabilities\"\n", + "]\n", + "rare_element_probabilities = {\n", + " element: prob\n", + " for element, prob in presence_probabilities.items()\n", + " if element not in [\"C\", \"O\", \"N\", \"H\"]\n", + "}\n", "\n", "# Create the probability plot\n", "fig, ax_prob = plt.subplots(figsize=(8, 4))\n", @@ -423,8 +495,8 @@ " hue=list(rare_element_probabilities.keys()),\n", " ax=ax_prob,\n", " palette=[ELEMENT_COLORS[element] for element in rare_element_probabilities.keys()],\n", - " edgecolor='black',\n", - " legend=False\n", + " edgecolor=\"black\",\n", + " legend=False,\n", ")\n", "\n", "# Apply hatching manually for ANY_RARE\n", @@ -568,12 +640,15 @@ } ], "source": [ - "\n", "# Find a large molecule with S in the structure\n", - "large_molecule_with_s = df[df['Metabolite'].apply(lambda x: 'S' in x.node_elements and x.ExactMolWeight > 500)].iloc[0]\n", - "very_large_molecule_with_s = df[df['Metabolite'].apply(lambda x: 'S' in x.node_elements and x.ExactMolWeight > 900)].iloc[2]\n", - "\n", - "large_molecule_with_s[\"Metabolite\"].draw(high_res=True)\n" + "large_molecule_with_s = df[\n", + " df[\"Metabolite\"].apply(lambda x: \"S\" in x.node_elements and x.ExactMolWeight > 500)\n", + "].iloc[0]\n", + "very_large_molecule_with_s = df[\n", + " df[\"Metabolite\"].apply(lambda x: \"S\" in x.node_elements and x.ExactMolWeight > 900)\n", + "].iloc[2]\n", + "\n", + "large_molecule_with_s[\"Metabolite\"].draw(high_res=True)" ] }, { @@ -802,12 +877,12 @@ "source": [ "import torch_geometric as geom\n", "\n", - "if torch.cuda.is_available(): \n", + "if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " dev = \"cuda:0\"\n", - "else: \n", + "else:\n", " dev = \"cpu\"\n", - " \n", + "\n", "print(f\"Running on device: {dev}\")" ] }, @@ -817,9 +892,10 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v0.1.1_OS_depth12_June25.pt\"\n", - "OLD_MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v0.1.1_OS_depth5_June25_ls1.pt\"\n", + "OLD_MODEL_PATH = (\n", + " f\"{home}/data/metabolites/pretrained_models/v0.1.1_OS_depth5_June25_ls1.pt\"\n", + ")\n", "\n", "from fiora.GNN.FioraModel import FioraModel\n", "from 
fiora.MS.SimulationFramework import SimulationFramework\n",
@@ -926,9 +1002,20 @@
 "from fiora.visualization.spectrum_visualizer import plot_spectrum\n",
 "\n",
 "\n",
- "large_stats = fiora.simulate_and_score(large_molecule_with_s[\"Metabolite\"], model, query_peaks=large_molecule_with_s[\"peaks\"])\n",
- "print(\"Cosine large (new model)\", large_stats[\"spectral_sqrt_cosine\"], \" and \", large_stats[\"spectral_sqrt_cosine_wo_prec\"])\n",
- "plot_spectrum(large_molecule_with_s, {\"peaks\": large_stats[\"sim_peaks\"]}, highlight_matches=True)\n"
+ "large_stats = fiora.simulate_and_score(\n",
+ "    large_molecule_with_s[\"Metabolite\"],\n",
+ "    model,\n",
+ "    query_peaks=large_molecule_with_s[\"peaks\"],\n",
+ ")\n",
+ "print(\n",
+ "    \"Cosine large (new model)\",\n",
+ "    large_stats[\"spectral_sqrt_cosine\"],\n",
+ "    \" and \",\n",
+ "    large_stats[\"spectral_sqrt_cosine_wo_prec\"],\n",
+ ")\n",
+ "plot_spectrum(\n",
+ "    large_molecule_with_s, {\"peaks\": large_stats[\"sim_peaks\"]}, highlight_matches=True\n",
+ ")"
 ]
},
{
@@ -969,23 +1056,51 @@
 "\n",
 "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n",
 "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n",
- "covariate_encoder = CovariateFeatureEncoder(feature_list=[\"collision_energy\", \"molecular_weight\", \"precursor_mode\", \"instrument\"], sets_overwrite=overwrite_setup_features)\n",
- "rt_encoder = CovariateFeatureEncoder(feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"], sets_overwrite=overwrite_setup_features)\n",
+ "covariate_encoder = CovariateFeatureEncoder(\n",
+ "    feature_list=[\n",
+ "        \"collision_energy\",\n",
+ "        \"molecular_weight\",\n",
+ "        \"precursor_mode\",\n",
+ "        \"instrument\",\n",
+ "    ],\n",
+ "    sets_overwrite=overwrite_setup_features,\n",
+ ")\n",
+ "rt_encoder = CovariateFeatureEncoder(\n",
+ "    feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"],\n",
+ "    sets_overwrite=overwrite_setup_features,\n",
+ ")\n",
 "\n",
- "covariate_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit \n",
- "covariate_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n",
- "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n",
+ "covariate_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n",
+ "covariate_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n",
+ "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n",
 "\n",
- "copy_large_molecule_with_s[\"Metabolite\"].compute_graph_attributes(node_encoder, bond_encoder)\n",
+ "copy_large_molecule_with_s[\"Metabolite\"].compute_graph_attributes(\n",
+ "    node_encoder, bond_encoder\n",
+ ")\n",
 "\n",
- "#df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n",
- "copy_large_molecule_with_s[\"Metabolite\"].add_metadata(copy_large_molecule_with_s[\"summary\"], covariate_encoder, rt_encoder)\n",
- "#_ = df.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)\n",
+ "# df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n",
+ "copy_large_molecule_with_s[\"Metabolite\"].add_metadata(\n",
+ "    copy_large_molecule_with_s[\"summary\"], covariate_encoder, rt_encoder\n",
+ ")\n",
+ "# _ = df.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)\n",
 "\n",
 "\n",
- "large_stats_old = fiora.simulate_and_score(copy_large_molecule_with_s[\"Metabolite\"], old_model, query_peaks=copy_large_molecule_with_s[\"peaks\"])\n",
- "print(f\"Cosine large (old model)\", large_stats_old['spectral_sqrt_cosine'], \"and\", large_stats_old['spectral_sqrt_cosine_wo_prec'])\n",
- "plot_spectrum(copy_large_molecule_with_s, {\"peaks\": large_stats_old[\"sim_peaks\"]}, highlight_matches=True)\n"
+ "large_stats_old = fiora.simulate_and_score(\n",
+ "    copy_large_molecule_with_s[\"Metabolite\"],\n",
+ "    old_model,\n",
+ "    query_peaks=copy_large_molecule_with_s[\"peaks\"],\n",
+ ")\n",
+ "print(\n",
+ "    \"Cosine large (old model)\",\n",
+ "    large_stats_old[\"spectral_sqrt_cosine\"],\n",
+ "    \"and\",\n",
+ "    large_stats_old[\"spectral_sqrt_cosine_wo_prec\"],\n",
+ ")\n",
+ "plot_spectrum(\n",
+ "    copy_large_molecule_with_s,\n",
+ "    {\"peaks\": large_stats_old[\"sim_peaks\"]},\n",
+ "    highlight_matches=True,\n",
+ ")"
 ]
},
{
@@ -1022,9 +1137,22 @@
 }
],
"source": [
- "very_large_stats = fiora.simulate_and_score(very_large_molecule_with_s[\"Metabolite\"], model, query_peaks=very_large_molecule_with_s[\"peaks\"])\n",
- "print(\"Cosine very large (new model)\", very_large_stats['spectral_sqrt_cosine'], \"and\", very_large_stats['spectral_sqrt_cosine_wo_prec'])\n",
- "plot_spectrum(very_large_molecule_with_s, {\"peaks\": very_large_stats[\"sim_peaks\"]}, highlight_matches=True)"
+ "very_large_stats = fiora.simulate_and_score(\n",
+ "    very_large_molecule_with_s[\"Metabolite\"],\n",
+ "    model,\n",
+ "    query_peaks=very_large_molecule_with_s[\"peaks\"],\n",
+ ")\n",
+ "print(\n",
+ "    \"Cosine very large (new model)\",\n",
+ "    very_large_stats[\"spectral_sqrt_cosine\"],\n",
+ "    \"and\",\n",
+ "    very_large_stats[\"spectral_sqrt_cosine_wo_prec\"],\n",
+ ")\n",
+ "plot_spectrum(\n",
+ "    very_large_molecule_with_s,\n",
+ "    {\"peaks\": very_large_stats[\"sim_peaks\"]},\n",
+ "    highlight_matches=True,\n",
+ ")"
 ]
},
{
@@ -1061,29 +1189,55 @@
 }
],
"source": [
- "\n",
 "copy_very_large_molecule_with_s = copy.deepcopy(very_large_molecule_with_s)\n",
 "\n",
 "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n",
 "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n",
- "covariate_encoder = CovariateFeatureEncoder(feature_list=[\"collision_energy\", \"molecular_weight\", \"precursor_mode\", \"instrument\"], sets_overwrite=overwrite_setup_features)\n",
- "rt_encoder = CovariateFeatureEncoder(feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"], sets_overwrite=overwrite_setup_features)\n",
+ "covariate_encoder = CovariateFeatureEncoder(\n",
+ "    feature_list=[\n",
+ "        \"collision_energy\",\n",
+ "        \"molecular_weight\",\n",
+ "        \"precursor_mode\",\n",
+ "        \"instrument\",\n",
+ "    ],\n",
+ "    sets_overwrite=overwrite_setup_features,\n",
+ ")\n",
+ "rt_encoder = CovariateFeatureEncoder(\n",
+ "    feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"],\n",
+ "    sets_overwrite=overwrite_setup_features,\n",
+ ")\n",
 "\n",
- "covariate_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit \n",
- "covariate_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n",
- "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n",
+ "covariate_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n",
+ "covariate_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n",
+ "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n",
 "\n",
- "copy_very_large_molecule_with_s[\"Metabolite\"].compute_graph_attributes(node_encoder, bond_encoder)\n",
+ "copy_very_large_molecule_with_s[\"Metabolite\"].compute_graph_attributes(\n",
+ "    node_encoder, bond_encoder\n",
+ ")\n",
 "\n",
- "#df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n",
- "copy_very_large_molecule_with_s[\"Metabolite\"].add_metadata(copy_very_large_molecule_with_s[\"summary\"], covariate_encoder, rt_encoder)\n",
- "#_ = df.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)\n",
+ "# df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n",
+ "copy_very_large_molecule_with_s[\"Metabolite\"].add_metadata(\n",
+ "    copy_very_large_molecule_with_s[\"summary\"], covariate_encoder, rt_encoder\n",
+ ")\n",
+ "# _ = df.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)\n",
 "\n",
 "\n",
- "very_large_stats_old = fiora.simulate_and_score(copy_very_large_molecule_with_s[\"Metabolite\"], old_model, query_peaks=copy_very_large_molecule_with_s[\"peaks\"])\n",
- "print(f\"Cosine very large (old model)\", very_large_stats_old['spectral_sqrt_cosine'], \"and \", very_large_stats_old['spectral_sqrt_cosine_wo_prec'])\n",
- "plot_spectrum(copy_very_large_molecule_with_s, {\"peaks\": very_large_stats_old[\"sim_peaks\"]}, highlight_matches=True)\n",
- "\n"
+ "very_large_stats_old = fiora.simulate_and_score(\n",
+ "    copy_very_large_molecule_with_s[\"Metabolite\"],\n",
+ "    old_model,\n",
+ "    query_peaks=copy_very_large_molecule_with_s[\"peaks\"],\n",
+ ")\n",
+ "print(\n",
+ "    \"Cosine very large (old model)\",\n",
+ "    very_large_stats_old[\"spectral_sqrt_cosine\"],\n",
+ "    \"and \",\n",
+ "    very_large_stats_old[\"spectral_sqrt_cosine_wo_prec\"],\n",
+ ")\n",
+ "plot_spectrum(\n",
+ "    copy_very_large_molecule_with_s,\n",
+ "    {\"peaks\": very_large_stats_old[\"sim_peaks\"]},\n",
+ "    highlight_matches=True,\n",
+ ")"
 ]
}
],
diff --git a/notebooks/live_predict.ipynb b/notebooks/live_predict.ipynb
index 325210b..f3a887e 100644
--- a/notebooks/live_predict.ipynb
+++ b/notebooks/live_predict.ipynb
@@ -19,7 +19,7 @@
 "\n",
 "\n",
 "seed = 42\n",
- "#torch.set_default_dtype(torch.float64)\n",
+ "# torch.set_default_dtype(torch.float64)\n",
 "torch.manual_seed(seed)\n",
 "torch.set_printoptions(precision=2, sci_mode=False)\n",
 "\n",
@@ -28,12 +28,13 @@
 "import numpy as np\n",
 "import ast\n",
 "import copy\n",
- "import matplotlib.pyplot as plt \n",
+ "import matplotlib.pyplot as plt\n",
 "import seaborn as sns\n",
 "\n",
 "# Load Modules\n",
 "sys.path.append(\"..\")\n",
 "from os.path import expanduser\n",
+ "\n",
 "home = expanduser(\"~\")\n",
 "from fiora.MOL.constants import DEFAULT_PPM, PPM, DEFAULT_MODES\n",
 "from fiora.IO.LibraryLoader import LibraryLoader\n",
@@ -52,9 +53,10 @@
 "from sklearn.metrics import r2_score\n",
 "import scipy\n",
 "from rdkit import RDLogger\n",
- "RDLogger.DisableLog('rdApp.*')\n",
 "\n",
- "print(f'Working with Python {sys.version}')\n"
+ "RDLogger.DisableLog(\"rdApp.*\")\n",
+ "\n",
+ "print(f\"Working with Python {sys.version}\")"
 ]
},
{
@@ -73,16 +75,18 @@
 "depth = 6\n",
 "MODEL_PATH = f\"../models/fiora_OS_v0.1.0.pt\"\n",
 "\n",
- "try: \n",
+ "try:\n",
 "    model = GNNCompiler.load(MODEL_PATH)\n",
 "except:\n",
 "    try:\n",
- "        print(f\"Warning: Failed loading the model {MODEL_PATH}. 
Fall back: Loading the model from state dictionary.\")\n",
+ "        print(\n",
+ "            f\"Warning: Failed loading the model {MODEL_PATH}. Falling back to loading the model from the state dictionary.\"\n",
+ "        )\n",
 "        model = GNNCompiler.load_from_state_dict(MODEL_PATH)\n",
 "        print(\"Model loaded from state dict without further errors.\")\n",
 "    except:\n",
 "        raise NameError(\"Error: Failed loading from state dict.\")\n",
- "    \n",
+ "\n",
 "\n",
 "dev = \"cuda:1\"\n",
 "\n",
@@ -138,10 +142,21 @@
 "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n",
 "model_setup_feature_sets = None\n",
 "if \"setup_features_categorical_set\" in model.model_params.keys():\n",
- "    model_setup_feature_sets = model.model_params[\"setup_features_categorical_set\"] \n",
+ "    model_setup_feature_sets = model.model_params[\"setup_features_categorical_set\"]\n",
 "\n",
- "setup_encoder = SetupFeatureEncoder(feature_list=[\"collision_energy\", \"molecular_weight\", \"precursor_mode\", \"instrument\"], sets_overwrite=model_setup_feature_sets)\n",
- "rt_encoder = SetupFeatureEncoder(feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"], sets_overwrite=model_setup_feature_sets)"
+ "setup_encoder = SetupFeatureEncoder(\n",
+ "    feature_list=[\n",
+ "        \"collision_energy\",\n",
+ "        \"molecular_weight\",\n",
+ "        \"precursor_mode\",\n",
+ "        \"instrument\",\n",
+ "    ],\n",
+ "    sets_overwrite=model_setup_feature_sets,\n",
+ ")\n",
+ "rt_encoder = SetupFeatureEncoder(\n",
+ "    feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"],\n",
+ "    sets_overwrite=model_setup_feature_sets,\n",
+ ")"
 ]
},
{
@@ -167,7 +182,12 @@
 "from fiora.MOL.Metabolite import Metabolite\n",
 "\n",
 "smiles = \"CC(C)C(CCCN(C)CCC1=CC(=C(C=C1)OC)OC)(C#N)C2=CC(=C(C=C2)OC)OC\"\n",
- "summary = {\"name\": \"Verapamil\",\"precursor_mode\": \"[M+H]+\", \"collision_energy\": 25.0, \"instrument\": \"HCD\"}\n",
+ "summary = {\n",
+ "    \"name\": \"Verapamil\",\n",
+ "    \"precursor_mode\": \"[M+H]+\",\n",
+ "    \"collision_energy\": 25.0,\n",
+ "    \"instrument\": \"HCD\",\n",
+ "}\n",
 "\n",
 "metabolite = Metabolite(smiles)"
 ]
},
{
@@ -228,8 +248,10 @@
 }
],
"source": [
- "fig, axs = plt.subplots(1, 2, figsize=(12,3), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n",
- "img = metabolite.draw(ax= axs[0])\n",
+ "fig, axs = plt.subplots(\n",
+ "    1, 2, figsize=(12, 3), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n",
+ ")\n",
+ "img = metabolite.draw(ax=axs[0])\n",
 "axs[0].set_title(summary[\"name\"])\n",
 "sv.plot_spectrum({\"peaks\": pred[\"sim_peaks\"]}, ax=axs[1])\n",
 "plt.show()"
 ]
},
{
@@ -273,14 +295,18 @@
 "summary[\"instrument\"] = \"Q-TOF\"\n",
 "metabolite.add_metadata(summary, setup_encoder, rt_encoder)\n",
 "pred = fiora.simulate_and_score(metabolite, model=model)\n",
- "fig, axs = plt.subplots(1, 2, figsize=(12,3), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n",
- "img = metabolite.draw(ax= axs[0])\n",
+ "fig, axs = plt.subplots(\n",
+ "    1, 2, figsize=(12, 3), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n",
+ ")\n",
+ "img = metabolite.draw(ax=axs[0])\n",
 "axs[0].set_title(summary[\"name\"])\n",
 "sv.plot_spectrum({\"peaks\": pred[\"sim_peaks\"]}, ax=axs[1])\n",
 "plt.show()\n",
 "\n",
 "if model_setup_feature_sets and \"Q-TOF\" not in model_setup_feature_sets[\"instrument\"]:\n",
- "    print(\"Instrument type: Q-TOF is not a default input of the selected model. The result might not be accurate.\")"
+ "    print(\n",
+ "        \"Instrument type: Q-TOF is not a default input of the selected model. 
The result might not be accurate.\"\n", + " )" ] }, { @@ -328,14 +354,14 @@ ], "source": [ "summary[\"instrument\"] = \"HCD\"\n", - "energy_levels = {\"low\": 15.0, \"moderate\": 25.0, \"high\": 35.0} \n", + "energy_levels = {\"low\": 15.0, \"moderate\": 25.0, \"high\": 35.0}\n", "for level, ce in energy_levels.items():\n", " summary[\"collision_energy\"] = ce\n", " metabolite.add_metadata(summary, setup_encoder, rt_encoder)\n", " pred = fiora.simulate_and_score(metabolite, model=model)\n", - " fig, ax = plt.subplots(1, 1, figsize=(12,4))\n", + " fig, ax = plt.subplots(1, 1, figsize=(12, 4))\n", " sv.plot_spectrum({\"peaks\": pred[\"sim_peaks\"]}, ax=ax)\n", - " plt.title(\"Collision energy level: \" + r\"$\\bf{\"f\"{level}\" + \"}$\" + f\" ({ce} eV)\")\n", + " plt.title(\"Collision energy level: \" + r\"$\\bf{\" f\"{level}\" + \"}$\" + f\" ({ce} eV)\")\n", " plt.show()" ] }, @@ -372,8 +398,12 @@ ], "source": [ "significant_peak_num = np.argmax(pred[\"sim_peaks\"][\"intensity\"])\n", - "fragment_smiles, ion_mode = pred[\"sim_peaks\"][\"annotation\"][significant_peak_num].split(\"//\")\n", - "print(f\"Most significant (non-precursor fragment) {fragment_smiles} found in ionization mode {ion_mode}.\\nThe hydrogen losses suggest a formation of a double bond somewhere in the structure below.\")\n", + "fragment_smiles, ion_mode = pred[\"sim_peaks\"][\"annotation\"][significant_peak_num].split(\n", + " \"//\"\n", + ")\n", + "print(\n", + " f\"Most significant (non-precursor fragment) {fragment_smiles} found in ionization mode {ion_mode}.\\nThe hydrogen losses suggest a formation of a double bond somewhere in the structure below.\"\n", + ")\n", "Metabolite(fragment_smiles).draw()\n", "plt.title(f\"{fragment_smiles} ({ion_mode} ionization)\")\n", "plt.show()" diff --git a/notebooks/sandbox.ipynb b/notebooks/sandbox.ipynb index c22ae23..64022e6 100644 --- a/notebooks/sandbox.ipynb +++ b/notebooks/sandbox.ipynb @@ -27,7 +27,7 @@ "import torch\n", "\n", "seed = 42\n", - "#torch.set_default_dtype(torch.float64)\n", + "# torch.set_default_dtype(torch.float64)\n", "torch.manual_seed(seed)\n", "torch.set_printoptions(precision=2, sci_mode=False)\n", "\n", @@ -40,18 +40,20 @@ "# Load Modules\n", "sys.path.append(\"..\")\n", "from os.path import expanduser\n", + "\n", "home = expanduser(\"~\")\n", "from fiora.MOL.constants import DEFAULT_PPM, PPM, DEFAULT_MODES\n", "from fiora.IO.LibraryLoader import LibraryLoader\n", - "from fiora.MOL.FragmentationTree import FragmentationTree \n", + "from fiora.MOL.FragmentationTree import FragmentationTree\n", "import fiora.visualization.spectrum_visualizer as sv\n", "\n", "from sklearn.metrics import r2_score\n", "import scipy\n", "from rdkit import RDLogger\n", - "RDLogger.DisableLog('rdApp.*')\n", "\n", - "print(f'Working with Python {sys.version}')\n" + "RDLogger.DisableLog(\"rdApp.*\")\n", + "\n", + "print(f\"Working with Python {sys.version}\")" ] }, { @@ -78,12 +80,15 @@ ], "source": [ "from typing import Literal\n", - "lib: Literal[\"NIST\", \"MSDIAL\", \"NIST/MSDIAL\", \"MSnLib\"] = \"MSnLib\" #\"MSnLib\"\n", + "\n", + "lib: Literal[\"NIST\", \"MSDIAL\", \"NIST/MSDIAL\", \"MSnLib\"] = \"MSnLib\" # \"MSnLib\"\n", "print(f\"Preparing {lib} library\")\n", "\n", - "test_run = True # Default: False\n", + "test_run = True # Default: False\n", "if test_run:\n", - " print(\"+++ This is a test run with a small subset of data points. Results are not representative. +++\")" + " print(\n", + " \"+++ This is a test run with a small subset of data points. 
Results are not representative. +++\"\n", + " )" ] }, { @@ -94,41 +99,45 @@ "source": [ "# key map to read metadata from pandas DataFrame\n", "metadata_key_map = {\n", - " \"name\": \"Name\",\n", - " \"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"ionization\": \"Ionization\",\n", - " \"precursor_mz\": \"PrecursorMZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", - " \"ccs\": \"CCS\"\n", - " }\n", + " \"name\": \"Name\",\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"ionization\": \"Ionization\",\n", + " \"precursor_mz\": \"PrecursorMZ\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"RETENTIONTIME\",\n", + " \"ccs\": \"CCS\",\n", + "}\n", "\n", "\n", "#\n", "# Load specified libraries and align metadata\n", "#\n", "\n", + "\n", "def load_training_data():\n", - " if (\"NIST\" in lib or \"MSDIAL\" in lib):\n", + " if \"NIST\" in lib or \"MSDIAL\" in lib:\n", " data_path: str = f\"{home}/data/metabolites/preprocessed/datasplits_Jan24.csv\"\n", " elif lib == \"MSnLib\":\n", - " data_path: str = f\"{home}/data/metabolites/preprocessed/datasplits_msnlib_Aug24_v3.csv\"\n", + " data_path: str = (\n", + " f\"{home}/data/metabolites/preprocessed/datasplits_msnlib_Aug24_v3.csv\"\n", + " )\n", " else:\n", " raise NameError(f\"Unknown library selected {lib=}.\")\n", " L = LibraryLoader()\n", " df = L.load_from_csv(data_path)\n", " return df\n", "\n", + "\n", "df = load_training_data()\n", "\n", "# Restore dictionary values\n", "dict_columns = [\"peaks\", \"summary\"]\n", "for col in dict_columns:\n", - " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", - " #df[col] = df[col].apply(ast.literal_eval)\n", - " \n", - "df['group_id'] = df['group_id'].astype(int)\n" + " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace(\"nan\", \"None\")))\n", + " # df[col] = df[col].apply(ast.literal_eval)\n", + "\n", + "df[\"group_id\"] = df[\"group_id\"].astype(int)" ] }, { @@ -159,15 +168,14 @@ "\n", "\n", "if test_run:\n", - " df = df.iloc[:10000,:]\n", - " #df = df.iloc[5000:20000,:]\n", + " df = df.iloc[:10000, :]\n", + " # df = df.iloc[5000:20000,:]\n", "\n", "overwrite_setup_features = None\n", "if lib == \"MSnLib\":\n", " overwrite_setup_features = {\n", " \"instrument\": [\"HCD\"],\n", - " \"precursor_mode\": [\"[M+H]+\", \"[M-H]-\", \"[M]+\", \"[M]-\"]\n", - " \n", + " \"precursor_mode\": [\"[M+H]+\", \"[M-H]-\", \"[M]+\", \"[M]-\"],\n", " }\n", "\n", "df[\"Metabolite\"] = df[\"SMILES\"].apply(Metabolite)\n", @@ -175,18 +183,32 @@ "\n", "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", - "setup_encoder = SetupFeatureEncoder(feature_list=[\"collision_energy\", \"molecular_weight\", \"precursor_mode\", \"instrument\"], sets_overwrite=overwrite_setup_features)\n", - "rt_encoder = SetupFeatureEncoder(feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"], sets_overwrite=overwrite_setup_features)\n", - "\n", - "setup_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit \n", - "setup_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n", - "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n", + "setup_encoder = SetupFeatureEncoder(\n", + " feature_list=[\n", + " 
\"collision_energy\",\n", + " \"molecular_weight\",\n", + " \"precursor_mode\",\n", + " \"instrument\",\n", + " ],\n", + " sets_overwrite=overwrite_setup_features,\n", + ")\n", + "rt_encoder = SetupFeatureEncoder(\n", + " feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"],\n", + " sets_overwrite=overwrite_setup_features,\n", + ")\n", + "\n", + "setup_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", + "setup_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", + "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", "\n", "df[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]) , axis=1)\n", + "df.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]), axis=1)\n", "\n", - "#df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", - "df.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + "# df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", + "df.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + ")\n", "df.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)" ] }, @@ -198,7 +220,12 @@ "source": [ "%%capture\n", "df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "df.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]), axis=1)" + "df.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -247,7 +274,7 @@ "# matches = [(mz, info[\"relative_intensity\"]) for (mz, info) in data[\"Metabolite\"].peak_matches.items() if info[\"relative_intensity\"] > 0.1]\n", "# print(data[\"Metabolite\"].SMILES)\n", "# data[\"Metabolite\"].draw()\n", - " \n", + "\n", "# plt.show()\n", "# #print(matches)\n", "# sv.plot_spectrum(data)\n", @@ -261,7 +288,7 @@ "metadata": {}, "outputs": [], "source": [ - "path: str = f'{home}/data/metabolites/preprocessed/rings_msnlib.csv'\n", + "path: str = f\"{home}/data/metabolites/preprocessed/rings_msnlib.csv\"\n", "df_rings = pd.read_csv(path)" ] }, @@ -274,18 +301,30 @@ "%%capture\n", "df_rings[\"Metabolite\"] = df_rings[\"SMILES\"].apply(Metabolite)\n", "df_rings[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", - "df_rings[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df_rings.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]) , axis=1)\n", + "df_rings[\"Metabolite\"].apply(\n", + " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", + ")\n", + "df_rings.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]), axis=1)\n", "\n", "dict_columns = [\"peaks\", \"summary\"]\n", "for col in dict_columns:\n", - " df_rings[col] = df_rings[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", - "\n", - "#df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", - "df_rings.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + " df_rings[col] = 
df_rings[col].apply(\n", + " lambda x: ast.literal_eval(x.replace(\"nan\", \"None\"))\n", + " )\n", + "\n", + "# df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", + "df_rings.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + ")\n", "df_rings.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)\n", "df_rings[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "df_rings.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]), axis=1)" + "df_rings.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -550,7 +589,11 @@ } ], "source": [ - "for i, row in df_rings[df_rings[\"Metabolite\"].apply(lambda x: bool(x.ring_proportion == 1.0))].drop_duplicates(subset=[\"group_id\"]).iterrows():\n", + "for i, row in (\n", + " df_rings[df_rings[\"Metabolite\"].apply(lambda x: bool(x.ring_proportion == 1.0))]\n", + " .drop_duplicates(subset=[\"group_id\"])\n", + " .iterrows()\n", + "):\n", " print(i)\n", " row[\"Metabolite\"].draw(show=True)" ] @@ -593,8 +636,12 @@ "metadata": {}, "outputs": [], "source": [ - "candidate_1 = df_rings.iloc[6]#df[df[\"Metabolite\"].apply(lambda x: x.SMILES == \"Nc1n[nH]c(N)c1N=Nc1ccc(O)cc1\")].iloc[0]\n", - "candidate_2 = df_rings.loc[4192]# negative:loc[3145] # also: 2557, 3774 #df[df[\"Metabolite\"].apply(lambda x: x.SMILES == \"c1ccc2Nc3ccccc3C(N3CCNCC3)=Nc2c1\")]" + "candidate_1 = df_rings.iloc[\n", + " 6\n", + "] # df[df[\"Metabolite\"].apply(lambda x: x.SMILES == \"Nc1n[nH]c(N)c1N=Nc1ccc(O)cc1\")].iloc[0]\n", + "candidate_2 = df_rings.loc[\n", + " 4192\n", + "] # negative:loc[3145] # also: 2557, 3774 #df[df[\"Metabolite\"].apply(lambda x: x.SMILES == \"c1ccc2Nc3ccccc3C(N3CCNCC3)=Nc2c1\")]" ] }, { @@ -642,15 +689,38 @@ ], "source": [ "import matplotlib.pyplot as plt\n", - "fig, axs = plt.subplots(1, 2, figsize=(14, 4), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", - "img = candidate_1[\"Metabolite\"].draw(ax= axs[0])\n", "\n", - "axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False) \n", - "axs[0].set_title(\"Name: \" + candidate_1[\"NAME\"] + \"\\nCollision energy: \" + str(candidate_1[\"CE\"]) + \" eV\")\n", - "\n", - "axs[1] = sv.plot_spectrum(candidate_1, None, ax=axs[1], highlight_matches=True, mz_matches=candidate_1[\"Metabolite\"].peak_matches.keys())\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(14, 4), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + ")\n", + "img = candidate_1[\"Metabolite\"].draw(ax=axs[0])\n", + "\n", + "axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + ")\n", + "axs[0].set_title(\n", + " \"Name: \"\n", + " + candidate_1[\"NAME\"]\n", + " + \"\\nCollision energy: \"\n", + " + str(candidate_1[\"CE\"])\n", + " + \" eV\"\n", + ")\n", + "\n", + "axs[1] = sv.plot_spectrum(\n", + " candidate_1,\n", + " None,\n", + " ax=axs[1],\n", + " highlight_matches=True,\n", + " mz_matches=candidate_1[\"Metabolite\"].peak_matches.keys(),\n", + ")\n", "plt.show()\n", - "print([(candidate_1[\"peaks\"][\"mz\"][i], candidate_1[\"peaks\"][\"intensity\"][i]) for i in range(len((candidate_1[\"peaks\"][\"mz\"]))) if 
candidate_1[\"peaks\"][\"intensity\"][i] > 20])" + "print(\n", + " [\n", + " (candidate_1[\"peaks\"][\"mz\"][i], candidate_1[\"peaks\"][\"intensity\"][i])\n", + " for i in range(len((candidate_1[\"peaks\"][\"mz\"])))\n", + " if candidate_1[\"peaks\"][\"intensity\"][i] > 20\n", + " ]\n", + ")" ] }, { @@ -678,15 +748,38 @@ ], "source": [ "import matplotlib.pyplot as plt\n", - "fig, axs = plt.subplots(1, 2, figsize=(14, 4), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", - "img = candidate_2[\"Metabolite\"].draw(ax= axs[0])\n", - "\n", - "axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False) \n", - "axs[0].set_title(\"Name: \" + candidate_2[\"NAME\"] + \"\\nCollision energy: \" + str(candidate_2[\"CE\"]) + \" eV\")\n", "\n", - "axs[1] = sv.plot_spectrum(candidate_2, None, ax=axs[1], highlight_matches=True, mz_matches=candidate_2[\"Metabolite\"].peak_matches.keys())\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(14, 4), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + ")\n", + "img = candidate_2[\"Metabolite\"].draw(ax=axs[0])\n", + "\n", + "axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + ")\n", + "axs[0].set_title(\n", + " \"Name: \"\n", + " + candidate_2[\"NAME\"]\n", + " + \"\\nCollision energy: \"\n", + " + str(candidate_2[\"CE\"])\n", + " + \" eV\"\n", + ")\n", + "\n", + "axs[1] = sv.plot_spectrum(\n", + " candidate_2,\n", + " None,\n", + " ax=axs[1],\n", + " highlight_matches=True,\n", + " mz_matches=candidate_2[\"Metabolite\"].peak_matches.keys(),\n", + ")\n", "plt.show()\n", - "print([(candidate_2[\"peaks\"][\"mz\"][i], candidate_2[\"peaks\"][\"intensity\"][i]) for i in range(len((candidate_2[\"peaks\"][\"mz\"]))) if candidate_2[\"peaks\"][\"intensity\"][i] > 20])" + "print(\n", + " [\n", + " (candidate_2[\"peaks\"][\"mz\"][i], candidate_2[\"peaks\"][\"intensity\"][i])\n", + " for i in range(len((candidate_2[\"peaks\"][\"mz\"])))\n", + " if candidate_2[\"peaks\"][\"intensity\"][i] > 20\n", + " ]\n", + ")" ] }, { @@ -701,6 +794,7 @@ "# Create a dictionary for exact weights of all elements\n", "element_weights = {el.symbol: el.mass for el in elements if el.mass is not None}\n", "\n", + "\n", "def get_exact_mass(elem_dict: Dict[str, int], precursor_mass: float = None) -> float:\n", " exact_mass = sum([element_weights[key] * value for key, value in elem_dict.items()])\n", " if precursor_mass is not None:\n", @@ -742,7 +836,12 @@ } ], "source": [ - "print(get_exact_mass({\"C\": 1, \"H\": 1, \"N\": 1}, candidate_2[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=None)))" + "print(\n", + " get_exact_mass(\n", + " {\"C\": 1, \"H\": 1, \"N\": 1},\n", + " candidate_2[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=None),\n", + " )\n", + ")" ] }, { @@ -759,7 +858,12 @@ } ], "source": [ - "print(get_exact_mass({\"C\": 3, \"H\": 3, \"N\": 1}, candidate_2[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=None)))" + "print(\n", + " get_exact_mass(\n", + " {\"C\": 3, \"H\": 3, \"N\": 1},\n", + " candidate_2[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=None),\n", + " )\n", + ")" ] }, { @@ -942,23 +1046,35 @@ "df_cas[\"Metabolite\"] = df_cas[\"SMILES\"].apply(Metabolite)\n", "df_cas[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cas[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df_cas[\"CE\"] = 20.0 # actually stepped 20/35/50\n", - "df_cas[\"Instrument_type\"] = 
\"HCD\" # CHECK if correct Orbitrap\n", - "\n", - "metadata_key_map16 = {\"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"PRECURSOR_MZ\",\n", - " 'precursor_mode': \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\"\n", - " }\n", + "df_cas[\"Metabolite\"].apply(\n", + " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", + ")\n", + "df_cas[\"CE\"] = 20.0 # actually stepped 20/35/50\n", + "df_cas[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", + "\n", + "metadata_key_map16 = {\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"precursor_mz\": \"PRECURSOR_MZ\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"RETENTIONTIME\",\n", + "}\n", "\n", - "df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n", - "df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder), axis=1)\n", + "df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", + ")\n", + "df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder), axis=1\n", + ")\n", "\n", "# Fragmentation\n", "df_cas[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "df_cas.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM), axis=1) # Optional: use mz_cut instead\n", + "df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " ),\n", + " axis=1,\n", + ") # Optional: use mz_cut instead\n", "\n", "#\n", "# CASMI 22\n", @@ -967,22 +1083,37 @@ "df_cas22[\"Metabolite\"] = df_cas22[\"SMILES\"].apply(Metabolite)\n", "df_cas22[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cas22[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df_cas22[\"CE\"] = df_cas22.apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"precursor_mz\"]), axis=1)\n", - "\n", - "metadata_key_map22 = {\"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"precursor_mz\",\n", - " 'precursor_mode': \"Precursor_type\",\n", - " \"retention_time\": \"ChallengeRT\"\n", - " }\n", + "df_cas22[\"Metabolite\"].apply(\n", + " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", + ")\n", + "df_cas22[\"CE\"] = df_cas22.apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"precursor_mz\"]), axis=1\n", + ")\n", + "\n", + "metadata_key_map22 = {\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"precursor_mz\": \"precursor_mz\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"ChallengeRT\",\n", + "}\n", "\n", - "df_cas22[\"summary\"] = df_cas22.apply(lambda x: {key: x[name] for key, name in metadata_key_map22.items()}, axis=1)\n", - "df_cas22.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + "df_cas22[\"summary\"] = df_cas22.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map22.items()}, axis=1\n", + ")\n", + "df_cas22.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + ")\n", "\n", "# Fragmentation\n", 
"df_cas22[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "df_cas22.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM), axis=1) # Optional: use mz_cut instead\n", + "df_cas22.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " ),\n", + " axis=1,\n", + ") # Optional: use mz_cut instead\n", "\n", "df_cas22 = df_cas22.reset_index()" ] @@ -1012,15 +1143,13 @@ "from fiora.GNN.Trainer import Trainer\n", "import torch_geometric as geom\n", "\n", - "if torch.cuda.is_available(): \n", + "if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " dev = \"cuda:1\"\n", - "else: \n", - " dev = \"cpu\" \n", + "else:\n", + " dev = \"cpu\"\n", "\n", - "print(f\"Running on device: {dev}\")\n", - "\n", - "\n" + "print(f\"Running on device: {dev}\")" ] }, { @@ -1088,41 +1217,38 @@ "outputs": [], "source": [ "model_params = {\n", - " 'param_tag': 'default',\n", - " 'gnn_type': 'RGCNConv',\n", - " 'depth': 6,\n", - " 'hidden_dimension': 300,\n", - " 'dense_layers': 2,\n", - " 'embedding_aggregation': 'concat',\n", - " 'embedding_dimension': 300,\n", - " 'input_dropout': 0.2,\n", - " 'latent_dropout': 0.1,\n", - " 'node_feature_layout': node_encoder.feature_numbers,\n", - " 'edge_feature_layout': bond_encoder.feature_numbers, \n", - " 'static_feature_dimension': geo_data[0][\"static_edge_features\"].shape[1],\n", - " 'static_rt_feature_dimension': geo_data[0][\"static_rt_features\"].shape[1],\n", - " 'output_dimension': len(DEFAULT_MODES) * 2, # per edge \n", - " \n", + " \"param_tag\": \"default\",\n", + " \"gnn_type\": \"RGCNConv\",\n", + " \"depth\": 6,\n", + " \"hidden_dimension\": 300,\n", + " \"dense_layers\": 2,\n", + " \"embedding_aggregation\": \"concat\",\n", + " \"embedding_dimension\": 300,\n", + " \"input_dropout\": 0.2,\n", + " \"latent_dropout\": 0.1,\n", + " \"node_feature_layout\": node_encoder.feature_numbers,\n", + " \"edge_feature_layout\": bond_encoder.feature_numbers,\n", + " \"static_feature_dimension\": geo_data[0][\"static_edge_features\"].shape[1],\n", + " \"static_rt_feature_dimension\": geo_data[0][\"static_rt_features\"].shape[1],\n", + " \"output_dimension\": len(DEFAULT_MODES) * 2, # per edge\n", " # Keep track of encoded features\n", - " 'atom_features': node_encoder.feature_list,\n", - " 'atom_features': bond_encoder.feature_list,\n", - " 'setup_features': setup_encoder.feature_list,\n", - " 'setup_features_categorical_set': setup_encoder.categorical_sets,\n", - " 'rt_features': rt_encoder.feature_list,\n", - " \n", + " \"atom_features\": node_encoder.feature_list,\n", + " \"atom_features\": bond_encoder.feature_list,\n", + " \"setup_features\": setup_encoder.feature_list,\n", + " \"setup_features_categorical_set\": setup_encoder.categorical_sets,\n", + " \"rt_features\": rt_encoder.feature_list,\n", " # Set default flags (May be overwritten below)\n", - " 'rt_supported': False,\n", - " 'ccs_supported': False,\n", - " 'version': \"x.x.x\"\n", - " \n", + " \"rt_supported\": False,\n", + " \"ccs_supported\": False,\n", + " \"version\": \"x.x.x\",\n", "}\n", "training_params = {\n", - " 'epochs': 200 if not test_run else 10, \n", - " 'batch_size': 256,\n", + " \"epochs\": 200 if not test_run else 10,\n", + " \"batch_size\": 256,\n", " #'train_val_split': 0.90,\n", - " 'learning_rate': 0.0004, # 0.00001 currently for wMAE # Default for wMSE is 0.0004, #0.001,\n", - " 'with_RT': 
False, # Turn off RT/CCS for initial trainings round\n", - " 'with_CCS': False\n", + " \"epochs\": 200 if not test_run else 10,\n", + " \"batch_size\": 256,\n", " #'train_val_split': 0.90,\n", + " \"learning_rate\": 0.0004, # 0.00001 currently for wMAE # Default for wMSE is 0.0004, #0.001,\n", + " \"with_RT\": False, # Turn off RT/CCS for the initial training round\n", + " \"with_CCS\": False,\n", "}" ] }, @@ -1141,51 +1267,95 @@ "outputs": [], "source": [ "from fiora.GNN.GNNModules import GNNCompiler\n", - "from fiora.GNN.Losses import WeightedMSELoss, WeightedMSEMetric, WeightedMAELoss, WeightedMAEMetric\n", + "from fiora.GNN.Losses import (\n", + " WeightedMSELoss,\n", + " WeightedMSEMetric,\n", + " WeightedMAELoss,\n", + " WeightedMAEMetric,\n", + ")\n", "from fiora.MS.SimulationFramework import SimulationFramework\n", "\n", "fiora = SimulationFramework(None, dev=dev, with_RT=True, with_CCS=True)\n", "# fiora = SimulationFramework(None, dev=dev, with_RT=training_params[\"with_RT\"], with_CCS=training_params[\"with_CCS\"])\n", - "np.seterr(invalid='ignore')\n", + "np.seterr(invalid=\"ignore\")\n", "tag = \"training\"\n", "val_interval = 1\n", - "metric_dict= {\"mse\": WeightedMSEMetric} #WeightedMSEMetric\n", - "loss_fn = WeightedMSELoss() # WeightedMSELoss()\n", + "metric_dict = {\"mse\": WeightedMSEMetric} # WeightedMSEMetric\n", + "loss_fn = WeightedMSELoss() # WeightedMSELoss()\n", "all_together = False\n", "\n", "if all_together:\n", " val_interval = 200\n", - " metric_dict=None\n", - " loss_fn = torch.nn.MSELoss() \n", + " metric_dict = None\n", + " loss_fn = torch.nn.MSELoss()\n", + "\n", "\n", "def train_new_model(continue_with_model=None):\n", " if continue_with_model:\n", " model = continue_with_model.to(dev)\n", " else:\n", " model = GNNCompiler(model_params).to(dev)\n", - " \n", - " y_label = 'compiled_probsSQRT' # y_label = 'compiled_probsALL'\n", - " optimizer = torch.optim.Adam(model.parameters(), lr=training_params[\"learning_rate\"])\n", + "\n", + " y_label = \"compiled_probsSQRT\" # y_label = 'compiled_probsALL'\n", + " optimizer = torch.optim.Adam(\n", + " model.parameters(), lr=training_params[\"learning_rate\"]\n", + " )\n", " if all_together:\n", - " trainer = Trainer(geo_data, y_tag=y_label, problem_type=\"regression\", only_training=True, metric_dict=metric_dict, split_by_group=True, seed=seed, device=dev)\n", + " trainer = Trainer(\n", + " geo_data,\n", + " y_tag=y_label,\n", + " problem_type=\"regression\",\n", + " only_training=True,\n", + " metric_dict=metric_dict,\n", + " split_by_group=True,\n", + " seed=seed,\n", + " device=dev,\n", + " )\n", " scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)\n", " else:\n", - " train_keys, val_keys = df[df[\"dataset\"] == \"training\"][\"group_id\"].unique(), df[df[\"dataset\"] == \"validation\"][\"group_id\"].unique()\n", - " trainer = Trainer(geo_data, y_tag=y_label, problem_type=\"regression\", train_keys=train_keys, val_keys=val_keys, metric_dict=metric_dict, split_by_group=True, seed=seed, device=dev)\n", - " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 8, factor=0.5, mode = 'min', verbose = True)\n", - "\n", - " \n", - " checkpoints = trainer.train(model, optimizer, loss_fn, scheduler=scheduler, batch_size=training_params['batch_size'], epochs=training_params[\"epochs\"], val_every_n_epochs=1, with_CCS=training_params[\"with_CCS\"], with_RT=training_params[\"with_RT\"], masked_validation=False, tag=tag) #, mask_name=\"compiled_validation_maskALL\") \n", + " train_keys, val_keys = (\n", + " df[df[\"dataset\"] == \"training\"][\"group_id\"].unique(),\n", + " 
df[df[\"dataset\"] == \"validation\"][\"group_id\"].unique(),\n", + " )\n", + " trainer = Trainer(\n", + " geo_data,\n", + " y_tag=y_label,\n", + " problem_type=\"regression\",\n", + " train_keys=train_keys,\n", + " val_keys=val_keys,\n", + " metric_dict=metric_dict,\n", + " split_by_group=True,\n", + " seed=seed,\n", + " device=dev,\n", + " )\n", + " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", + " optimizer, patience=8, factor=0.5, mode=\"min\", verbose=True\n", + " )\n", + "\n", + " checkpoints = trainer.train(\n", + " model,\n", + " optimizer,\n", + " loss_fn,\n", + " scheduler=scheduler,\n", + " batch_size=training_params[\"batch_size\"],\n", + " epochs=training_params[\"epochs\"],\n", + " val_every_n_epochs=1,\n", + " with_CCS=training_params[\"with_CCS\"],\n", + " with_RT=training_params[\"with_RT\"],\n", + " masked_validation=False,\n", + " tag=tag,\n", + " ) # , mask_name=\"compiled_validation_maskALL\")\n", " print(checkpoints)\n", " return model, checkpoints\n", "\n", + "\n", "def simulate_all(model, DF):\n", " return fiora.simulate_all(DF, model)\n", "\n", - " \n", + "\n", "def test_model(model, DF, score=\"spectral_sqrt_cosine\", return_df=False):\n", " dft = simulate_all(model, DF)\n", - " \n", + "\n", " if return_df:\n", " return dft\n", " return dft[score].values" @@ -1205,53 +1375,112 @@ "outputs": [], "source": [ "from fiora.MOL.collision_energy import NCE_to_eV\n", - "from fiora.MS.spectral_scores import spectral_cosine, spectral_reflection_cosine, reweighted_dot\n", + "from fiora.MS.spectral_scores import (\n", + " spectral_cosine,\n", + " spectral_reflection_cosine,\n", + " reweighted_dot,\n", + ")\n", "from fiora.MS.ms_utility import merge_annotated_spectrum\n", "\n", + "\n", "def test_cas16(model, df_cas=df_cas, score=\"merged_sqrt_cosine\", return_df=False):\n", - " \n", - " df_cas[\"NCE\"] = 20.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1)\n", + "\n", + " df_cas[\"NCE\"] = 20.0 # actually stepped NCE 20/35/50\n", + " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " )\n", " df_cas[\"step1_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n", - " df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + " df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", + " )\n", + " df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + " )\n", " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_20\")\n", "\n", - " df_cas[\"NCE\"] = 35.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1)\n", + " df_cas[\"NCE\"] = 35.0 # actually stepped NCE 20/35/50\n", + " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " )\n", " df_cas[\"step2_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n", - " df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), 
axis=1)\n", + " df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", + " )\n", + " df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + " )\n", " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_35\")\n", "\n", - "\n", - " df_cas[\"NCE\"] = 50.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1)\n", + " df_cas[\"NCE\"] = 50.0 # actually stepped NCE 20/35/50\n", + " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " )\n", " df_cas[\"step3_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n", - " df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + " df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", + " )\n", + " df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + " )\n", " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_50\")\n", "\n", - " df_cas[\"avg_CE\"] = (df_cas[\"step1_CE\"] + df_cas[\"step2_CE\"] + df_cas[\"step3_CE\"]) / 3\n", - "\n", - " df_cas[\"merged_peaks\"] = df_cas.apply(lambda x: merge_annotated_spectrum(merge_annotated_spectrum(x[\"sim_peaks_20\"], x[\"sim_peaks_35\"]), x[\"sim_peaks_50\"]) , axis=1)\n", - " df_cas[\"merged_cosine\"] = df_cas.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"]), axis=1)\n", - " df_cas[\"merged_sqrt_cosine\"] = df_cas.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt), axis=1)\n", - " df_cas[\"merged_sqrt_cosine_wo_prec\"] = df_cas.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt, remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(x[\"Metabolite\"].metadata[\"precursor_mode\"])), axis=1)\n", - " df_cas[\"merged_refl_cosine\"] = df_cas.apply(lambda x: spectral_reflection_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt), axis=1)\n", - " df_cas[\"merged_steins\"] = df_cas.apply(lambda x: reweighted_dot(x[\"peaks\"], x[\"merged_peaks\"]), axis=1)\n", - " df_cas[\"spectral_sqrt_cosine\"] = df_cas[\"merged_sqrt_cosine\"] # just remember it is merged\n", - " df_cas[\"spectral_sqrt_cosine_wo_prec\"] = df_cas[\"merged_sqrt_cosine_wo_prec\"] # just remember it is merged\n", + " df_cas[\"avg_CE\"] = (\n", + " df_cas[\"step1_CE\"] + df_cas[\"step2_CE\"] + df_cas[\"step3_CE\"]\n", + " ) / 3\n", + "\n", + " df_cas[\"merged_peaks\"] = df_cas.apply(\n", + " lambda x: merge_annotated_spectrum(\n", + " merge_annotated_spectrum(x[\"sim_peaks_20\"], x[\"sim_peaks_35\"]),\n", + " x[\"sim_peaks_50\"],\n", + " ),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_cosine\"] = df_cas.apply(\n", + " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " )\n", + " df_cas[\"merged_sqrt_cosine\"] = df_cas.apply(\n", + " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_sqrt_cosine_wo_prec\"] = df_cas.apply(\n", + " lambda x: spectral_cosine(\n", + " x[\"peaks\"],\n", + " x[\"merged_peaks\"],\n", + " transform=np.sqrt,\n", 
+ " remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(\n", + " x[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " ),\n", + " ),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_refl_cosine\"] = df_cas.apply(\n", + " lambda x: spectral_reflection_cosine(\n", + " x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt\n", + " ),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_steins\"] = df_cas.apply(\n", + " lambda x: reweighted_dot(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " )\n", + " df_cas[\"spectral_sqrt_cosine\"] = df_cas[\n", + " \"merged_sqrt_cosine\"\n", + " ] # just remember it is merged\n", + " df_cas[\"spectral_sqrt_cosine_wo_prec\"] = df_cas[\n", + " \"merged_sqrt_cosine_wo_prec\"\n", + " ] # just remember it is merged\n", "\n", " df_cas[\"coverage\"] = df_cas[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", " df_cas[\"RT_pred\"] = df_cas[\"RT_pred_35\"]\n", " df_cas[\"RT_dif\"] = df_cas[\"RT_dif_35\"]\n", " df_cas[\"CCS_pred\"] = df_cas[\"CCS_pred_35\"]\n", " df_cas[\"library\"] = \"CASMI-16\"\n", - " \n", + "\n", " if return_df:\n", " return df_cas\n", - " \n", + "\n", " return df_cas[score].values" ] }, @@ -1285,7 +1514,9 @@ " if \"phospho\" in d[\"name\"]:\n", " d.update({\"ce_steps\": [20, 30, 40]})\n", " return d\n", - "df_train[\"summary\"] = df_train[\"summary\"].apply(add_ce) " + "\n", + "\n", + "df_train[\"summary\"] = df_train[\"summary\"].apply(add_ce)" ] }, { @@ -1319,8 +1550,10 @@ } ], "source": [ - "\n", - "df_train.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder), axis=1)\n", + "df_train.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " axis=1,\n", + ")\n", "geo_data = df_train[\"Metabolite\"].apply(lambda x: x.as_geometric_data().to(dev)).values\n", "print(f\"Prepared training/validation with {len(geo_data)} data points\")" ] @@ -1334,6 +1567,7 @@ "import torch_geometric.loader as geom_loader\n", "from torch_geometric.data import Data\n", "\n", + "\n", "class DoubleBatchLoader(geom_loader.DataLoader):\n", " def __iter__(self):\n", " for batch in super().__iter__():\n", @@ -1341,14 +1575,11 @@ " for key, value in batch:\n", " # Duplicate each graph in an ordered way\n", " doubled_batch[key] = torch.cat([value, value], dim=0)\n", - " \n", + "\n", " # Update batch index to reflect duplication in order\n", " doubled_batch[\"batch\"] = torch.repeat_interleave(batch[\"batch\"], 2)\n", - " \n", - " yield doubled_batch\n", - "\n", "\n", - "\n" + " yield doubled_batch" ] }, { @@ -1372,9 +1603,8 @@ "loader_base = geom_loader.DataLoader\n", "dataloader = loader_base(geo_data, batch_size=5, num_workers=0, shuffle=False)\n", "for id, batch in enumerate(dataloader):\n", - "\n", " print(batch[\"ce_steps\"])\n", - " break\n" + " break" ] }, { @@ -1417,44 +1647,55 @@ ], "source": [ "for id, batch in enumerate(dataloader):\n", - " \n", - " print(batch[\"edge_index\"][0,:][:10])\n", - " print(torch.repeat_interleave(batch[\"edge_index\"][0,:], 5)[:50])\n", - " print(batch[\"batch\"][torch.repeat_interleave(batch[\"edge_index\"][0,:], 5)])\n", - " break\n", - " # Feed forward\n", - " model.train()\n", - " \n", - " y_pred = model(batch, with_RT=with_RT, with_CCS=with_CCS)\n", - " kwargs={}\n", - " if with_weights:\n", - " kwargs={\"weight\": batch[\"weight_tensor\"]}\n", - " \n", - " loss = loss_fn(y_pred[\"fragment_probs\"], batch[self.y_tag], **kwargs) # with logits\n", - " if not rt_metric: metrics(y_pred[\"fragment_probs\"], batch[self.y_tag], 
**kwargs) # call update\n", - "\n", - " # Add RT and CCS to loss\n", - " if with_RT:\n", - " if with_weights:\n", - " kwargs[\"weight\"] = batch[\"weight\"][batch[\"retention_mask\"]]\n", - " loss_rt = loss_fn(y_pred[\"rt\"][batch[\"retention_mask\"]], batch[\"retention_time\"][batch[\"retention_mask\"]], **kwargs) \n", - " loss = loss + loss_rt\n", - " \n", - " if with_CCS:\n", - " if with_weights:\n", - " kwargs[\"weight\"] = batch[\"weight\"][batch[\"ccs_mask\"]]\n", - " loss_ccs = loss_fn(y_pred[\"ccs\"][batch[\"ccs_mask\"]], batch[\"ccs\"][batch[\"ccs_mask\"]], **kwargs) \n", - " loss = loss + loss_ccs\n", - "\n", - " if rt_metric:\n", - " metrics(y_pred[\"rt\"][batch[\"retention_mask\"]], batch[\"retention_time\"][batch[\"retention_mask\"]], **kwargs) # call update\n", - " metrics(y_pred[\"ccs\"][batch[\"ccs_mask\"]], batch[\"ccs\"][batch[\"ccs_mask\"]], **kwargs) # call update\n", - " \n", - " \n", - " # Backpropagate\n", - " optimizer.zero_grad()\n", - " loss.backward()\n", - " optimizer.step()" + " print(batch[\"edge_index\"][0, :][:10])\n", + " print(torch.repeat_interleave(batch[\"edge_index\"][0, :], 5)[:50])\n", + " print(batch[\"batch\"][torch.repeat_interleave(batch[\"edge_index\"][0, :], 5)])\n", + " break\n", + " # Feed forward\n", + " model.train()\n", + "\n", + " y_pred = model(batch, with_RT=with_RT, with_CCS=with_CCS)\n", + " kwargs = {}\n", + " if with_weights:\n", + " kwargs = {\"weight\": batch[\"weight_tensor\"]}\n", + "\n", + " loss = loss_fn(y_pred[\"fragment_probs\"], batch[self.y_tag], **kwargs) # with logits\n", + " if not rt_metric:\n", + " metrics(y_pred[\"fragment_probs\"], batch[self.y_tag], **kwargs) # call update\n", + "\n", + " # Add RT and CCS to loss\n", + " if with_RT:\n", + " if with_weights:\n", + " kwargs[\"weight\"] = batch[\"weight\"][batch[\"retention_mask\"]]\n", + " loss_rt = loss_fn(\n", + " y_pred[\"rt\"][batch[\"retention_mask\"]],\n", + " batch[\"retention_time\"][batch[\"retention_mask\"]],\n", + " **kwargs,\n", + " )\n", + " loss = loss + loss_rt\n", + "\n", + " if with_CCS:\n", + " if with_weights:\n", + " kwargs[\"weight\"] = batch[\"weight\"][batch[\"ccs_mask\"]]\n", + " loss_ccs = loss_fn(\n", + " y_pred[\"ccs\"][batch[\"ccs_mask\"]], batch[\"ccs\"][batch[\"ccs_mask\"]], **kwargs\n", + " )\n", + " loss = loss + loss_ccs\n", + "\n", + " if rt_metric:\n", + " metrics(\n", + " y_pred[\"rt\"][batch[\"retention_mask\"]],\n", + " batch[\"retention_time\"][batch[\"retention_mask\"]],\n", + " **kwargs,\n", + " ) # call update\n", + " metrics(\n", + " y_pred[\"ccs\"][batch[\"ccs_mask\"]], batch[\"ccs\"][batch[\"ccs_mask\"]], **kwargs\n", + " ) # call update\n", + "\n", + " # Backpropagate\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()" ] }, { @@ -1723,7 +1964,7 @@ ], "source": [ "print(f\"Training model\")\n", - "model, checkpoints = train_new_model() # continue_with_model=model)" + "model, checkpoints = train_new_model() # continue_with_model=model)" ] }, { @@ -1743,7 +1984,7 @@ "source": [ "import copy\n", "\n", - "print(checkpoints) \n", + "print(checkpoints)\n", "print(np.sqrt(checkpoints[\"val_loss\"]))\n", "model_end = copy.deepcopy(model)" ] @@ -1843,19 +2084,37 @@ } ], "source": [ - "model = model_end#GNNCompiler.load(checkpoints[\"file\"]).to(dev)\n", + "model = model_end # GNNCompiler.load(checkpoints[\"file\"]).to(dev)\n", "score = \"spectral_sqrt_cosine\"\n", "\n", - "val_results = test_model(model, df_train[df_train[\"dataset\"]== \"validation\"], score=score)\n", + "val_results = test_model(\n", 
+ " model, df_train[df_train[\"dataset\"] == \"validation\"], score=score\n", + ")\n", "test_results = test_model(model, df_test, score=score)\n", "casmi16_results = test_cas16(model, score=score)\n", "casmi16_p = test_cas16(model, df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"], score=score)\n", "casmi16_n = test_cas16(model, df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"], score=score)\n", "casmi22_results = test_model(model, df_cas22, score=score)\n", - "casmi22_p = test_model(model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"], score=score)\n", - "casmi22_n = test_model(model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"], score=score)\n", - " \n", - "results = [{\"model\": model, \"validation\": val_results, \"test\": test_results, \"casmi16\": casmi16_results, \"casmi22\": casmi22_results, \"casmi16+\": casmi16_p, \"casmi16-\": casmi16_n, \"casmi22+\": casmi22_p, \"casmi22-\": casmi22_n}]" + "casmi22_p = test_model(\n", + " model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"], score=score\n", + ")\n", + "casmi22_n = test_model(\n", + " model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"], score=score\n", + ")\n", + "\n", + "results = [\n", + " {\n", + " \"model\": model,\n", + " \"validation\": val_results,\n", + " \"test\": test_results,\n", + " \"casmi16\": casmi16_results,\n", + " \"casmi22\": casmi22_results,\n", + " \"casmi16+\": casmi16_p,\n", + " \"casmi16-\": casmi16_n,\n", + " \"casmi22+\": casmi22_p,\n", + " \"casmi22-\": casmi22_n,\n", + " }\n", + "]" ] }, { @@ -1954,16 +2213,34 @@ ], "source": [ "score = \"spectral_sqrt_cosine_wo_prec\"\n", - "val_results = test_model(model, df_train[df_train[\"dataset\"]== \"validation\"], score=score)\n", + "val_results = test_model(\n", + " model, df_train[df_train[\"dataset\"] == \"validation\"], score=score\n", + ")\n", "test_results = test_model(model, df_test, score=score)\n", "casmi16_results = test_cas16(model, score=score)\n", "casmi16_p = test_cas16(model, df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"], score=score)\n", "casmi16_n = test_cas16(model, df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"], score=score)\n", "casmi22_results = test_model(model, df_cas22, score=score)\n", - "casmi22_p = test_model(model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"], score=score)\n", - "casmi22_n = test_model(model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"], score=score)\n", - " \n", - "results_wop = [{\"model\": model, \"validation\": val_results, \"test\": test_results, \"casmi16\": casmi16_results, \"casmi22\": casmi22_results, \"casmi16+\": casmi16_p, \"casmi16-\": casmi16_n, \"casmi22+\": casmi22_p, \"casmi22-\": casmi22_n}]" + "casmi22_p = test_model(\n", + " model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"], score=score\n", + ")\n", + "casmi22_n = test_model(\n", + " model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"], score=score\n", + ")\n", + "\n", + "results_wop = [\n", + " {\n", + " \"model\": model,\n", + " \"validation\": val_results,\n", + " \"test\": test_results,\n", + " \"casmi16\": casmi16_results,\n", + " \"casmi22\": casmi22_results,\n", + " \"casmi16+\": casmi16_p,\n", + " \"casmi16-\": casmi16_n,\n", + " \"casmi22+\": casmi22_p,\n", + " \"casmi22-\": casmi22_n,\n", + " }\n", + "]" ] }, { @@ -2111,7 +2388,9 @@ "source": [ "LOG = pd.DataFrame(results_wop).fillna(0.0)\n", "eval_columns = LOG.columns[1:]\n", - "LOG[eval_columns] = LOG[eval_columns].apply(lambda x: x.apply(np.nan_to_num).apply(np.median))\n", + "LOG[eval_columns] = 
LOG[eval_columns].apply(\n", + " lambda x: x.apply(np.nan_to_num).apply(np.median)\n", + ")\n", "\n", "LOG" ] @@ -2295,11 +2574,13 @@ "metadata": {}, "outputs": [], "source": [ - "dev=\"cuda:0\"\n", - "mymy = GNNCompiler.load(MODEL_PATH) # f\"{home}/data/metabolites/pretrained_models/v0.0.1_merged_2.pt\"\n", - "#mymy.load_state_dict(torch.load(f\"{home}/data/metabolites/pretrained_models/test.pt\"))\n", + "dev = \"cuda:0\"\n", + "mymy = GNNCompiler.load(\n", + " MODEL_PATH\n", + ") # f\"{home}/data/metabolites/pretrained_models/v0.0.1_merged_2.pt\"\n", + "# mymy.load_state_dict(torch.load(f\"{home}/data/metabolites/pretrained_models/test.pt\"))\n", "mymy.eval()\n", - "mymy = mymy.to(dev)\n" + "mymy = mymy.to(dev)" ] }, { @@ -2349,7 +2630,8 @@ "outputs": [], "source": [ "import json\n", - "with open(MODEL_PATH.replace(\".pt\", \"_params.json\"), 'r') as fp:\n", + "\n", + "with open(MODEL_PATH.replace(\".pt\", \"_params.json\"), \"r\") as fp:\n", " p = json.load(fp)\n", "hh = GNNCompiler(p)\n", "hh.load_state_dict(torch.load(MODEL_PATH.replace(\".pt\", \"_state.pt\")))\n", @@ -2426,19 +2708,24 @@ "source": [ "## prepare output for for CFM-ID\n", "import os\n", + "\n", "save_df = False\n", "cfm_directory = f\"{home}/data/metabolites/cfm-id/\"\n", "name = \"test_split_negative_solutions_cfm.txt\"\n", "df_cfm = df_test[[\"group_id\", \"SMILES\", \"Precursor_type\"]]\n", - "df_n = df_cfm[df_cfm[\"Precursor_type\"] == \"[M-H]-\"].drop_duplicates(subset='group_id', keep='first')\n", - "df_p = df_cfm[df_cfm[\"Precursor_type\"] == \"[M+H]+\"].drop_duplicates(subset='group_id', keep='first')\n", + "df_n = df_cfm[df_cfm[\"Precursor_type\"] == \"[M-H]-\"].drop_duplicates(\n", + " subset=\"group_id\", keep=\"first\"\n", + ")\n", + "df_p = df_cfm[df_cfm[\"Precursor_type\"] == \"[M+H]+\"].drop_duplicates(\n", + " subset=\"group_id\", keep=\"first\"\n", + ")\n", "\n", "print(df_n.head())\n", "\n", "if save_df:\n", " file = os.path.join(cfm_directory, name)\n", " df_n[[\"group_id\", \"SMILES\"]].to_csv(file, index=False, header=False, sep=\" \")\n", - " \n", + "\n", " name = name.replace(\"negative\", \"positive\")\n", " file = os.path.join(cfm_directory, name)\n", " df_p[[\"group_id\", \"SMILES\"]].to_csv(file, index=False, header=False, sep=\" \")" @@ -2508,16 +2795,16 @@ "source": [ "# Load best model\n", "\n", - "dev=\"cuda:0\"\n", - "#MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/pre_package/v0.0.1_merged_depth6_Jan24.pt\"\n", - "MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v0.0.1_merged_depth6_Aug24_sqrt.pt\" # New sqrt model (improved)\n", + "dev = \"cuda:0\"\n", + "# MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/pre_package/v0.0.1_merged_depth6_Jan24.pt\"\n", + "MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v0.0.1_merged_depth6_Aug24_sqrt.pt\" # New sqrt model (improved)\n", "\n", "try:\n", " model = GNNCompiler.load_from_state_dict(MODEL_PATH)\n", " print(\"Model loaded from state dict without errors.\")\n", "except:\n", " raise NameError(\"Error: Failed loading from state dict.\")\n", - " \n", + "\n", "\n", "model.eval()\n", "model = model.to(dev)\n", @@ -2531,7 +2818,13 @@ "metadata": {}, "outputs": [], "source": [ - "spectral_modules = [\"node_embedding\", \"edge_embedding\", \"GNN_module\", \"edge_module\", \"precursor_module\"]\n", + "spectral_modules = [\n", + " \"node_embedding\",\n", + " \"edge_embedding\",\n", + " \"GNN_module\",\n", + " \"edge_module\",\n", + " \"precursor_module\",\n", + "]\n", "\n", "\n", "for module in 
spectral_modules:\n", @@ -2589,8 +2882,7 @@ } ], "source": [ - "df_train[\"RTorCCS\"] = ~(df_train[\"RETENTIONTIME\"].isna() & df_train[\"CCS\"].isna())\n", - "\n" + "df_train[\"RTorCCS\"] = ~(df_train[\"RETENTIONTIME\"].isna() & df_train[\"CCS\"].isna())" ] }, { @@ -2610,10 +2902,24 @@ ], "source": [ "rt_index = df_train.drop_duplicates(\"group_id\", keep=\"first\")[\"RTorCCS\"]\n", - "print(\"RT: \", sum(~df_train.drop_duplicates(\"group_id\", keep=\"first\")[rt_index][\"RETENTIONTIME\"].isna()))\n", - "print(\"CCS: \", sum(~df_train.drop_duplicates(\"group_id\", keep=\"first\")[rt_index][\"CCS\"].isna()))\n", - "\n", - "geo_data = df_train.drop_duplicates(\"group_id\", keep=\"first\")[rt_index][\"Metabolite\"].apply(lambda x: x.as_geometric_data().to(dev)).values\n", + "print(\n", + " \"RT: \",\n", + " sum(\n", + " ~df_train.drop_duplicates(\"group_id\", keep=\"first\")[rt_index][\n", + " \"RETENTIONTIME\"\n", + " ].isna()\n", + " ),\n", + ")\n", + "print(\n", + " \"CCS: \",\n", + " sum(~df_train.drop_duplicates(\"group_id\", keep=\"first\")[rt_index][\"CCS\"].isna()),\n", + ")\n", + "\n", + "geo_data = (\n", + " df_train.drop_duplicates(\"group_id\", keep=\"first\")[rt_index][\"Metabolite\"]\n", + " .apply(lambda x: x.as_geometric_data().to(dev))\n", + " .values\n", + ")\n", "print(f\"Prepared training/validation with {len(geo_data)} data points\")" ] }, @@ -2623,17 +2929,46 @@ "metadata": {}, "outputs": [], "source": [ - "rt_epochs = 500 # 300\n", - "rt_batch = 64 #128\n", + "rt_epochs = 500 # 300\n", + "rt_batch = 64 # 128\n", "rt_lr = 0.005\n", "\n", - "def train_rt_model(rt_lr=rt_lr, rt_batch=rt_batch, rt_epochs=rt_epochs): \n", - " y_label = 'compiled_probsALL'\n", + "\n", + "def train_rt_model(rt_lr=rt_lr, rt_batch=rt_batch, rt_epochs=rt_epochs):\n", + " y_label = \"compiled_probsALL\"\n", " optimizer = torch.optim.Adam(model.parameters(), lr=rt_lr)\n", - " train_keys, val_keys = df[df[\"dataset\"] == \"training\"][\"group_id\"].unique(), df[df[\"dataset\"] == \"validation\"][\"group_id\"].unique()\n", - " trainer = Trainer(geo_data, y_tag=y_label, problem_type=\"regression\", train_keys=train_keys, val_keys=val_keys, metric_dict=None, split_by_group=True, seed=seed, device=dev)\n", - " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 8, factor=0.8, mode = 'min', verbose = True)\n", - " checkpoints = trainer.train(model, optimizer, loss_fn, scheduler=scheduler, batch_size=rt_batch, epochs=rt_epochs, val_every_n_epochs=1, with_CCS=True, with_RT=True, rt_metric=True, masked_validation=False, tag=tag) #, mask_name=\"compiled_validation_maskALL\") \n", + " train_keys, val_keys = (\n", + " df[df[\"dataset\"] == \"training\"][\"group_id\"].unique(),\n", + " df[df[\"dataset\"] == \"validation\"][\"group_id\"].unique(),\n", + " )\n", + " trainer = Trainer(\n", + " geo_data,\n", + " y_tag=y_label,\n", + " problem_type=\"regression\",\n", + " train_keys=train_keys,\n", + " val_keys=val_keys,\n", + " metric_dict=None,\n", + " split_by_group=True,\n", + " seed=seed,\n", + " device=dev,\n", + " )\n", + " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", + " optimizer, patience=8, factor=0.8, mode=\"min\", verbose=True\n", + " )\n", + " checkpoints = trainer.train(\n", + " model,\n", + " optimizer,\n", + " loss_fn,\n", + " scheduler=scheduler,\n", + " batch_size=rt_batch,\n", + " epochs=rt_epochs,\n", + " val_every_n_epochs=1,\n", + " with_CCS=True,\n", + " with_RT=True,\n", + " rt_metric=True,\n", + " masked_validation=False,\n", + " tag=tag,\n", + " ) # 
, mask_name=\"compiled_validation_maskALL\")\n", "\n", " return model, checkpoints" ] @@ -3237,7 +3572,7 @@ } ], "source": [ - "model, cp = train_rt_model() #Val RMSE 12.11 is too beat." + "model, cp = train_rt_model() # Val RMSE 12.11 is too beat." ] }, { @@ -3330,7 +3665,9 @@ "rt_batch = 256\n", "rt_lr = 0.00001\n", "model.set_dropout_rate(0.1, 0.05)\n", - "model, cp = train_rt_model(rt_lr=rt_lr, rt_batch=rt_batch, rt_epochs=rt_epochs) #Val RMSE 11.482 is too beat." + "model, cp = train_rt_model(\n", + " rt_lr=rt_lr, rt_batch=rt_batch, rt_epochs=rt_epochs\n", + ") # Val RMSE 11.482 is too beat." ] }, { @@ -3379,9 +3716,11 @@ } ], "source": [ - "val_df = test_model(model, df_train[df_train[\"dataset\"]== \"validation\"], return_df=True)\n", + "val_df = test_model(\n", + " model, df_train[df_train[\"dataset\"] == \"validation\"], return_df=True\n", + ")\n", "test_df = test_model(model, df_test, return_df=True)\n", - "casmi16_df = test_cas16(model, return_df=True) \n", + "casmi16_df = test_cas16(model, return_df=True)\n", "val_df[\"Dataset\"] = \"Val\"\n", "casmi16_df[\"Dataset\"] = \"CASMI 22\"\n", "test_df[\"Dataset\"] = \"Test split\"" @@ -3407,30 +3746,47 @@ "import seaborn as sns\n", "from fiora.visualization.define_colors import *\n", "\n", - "fig, axs = plt.subplots(2, 1, figsize=(12, 14), gridspec_kw={'height_ratios': [1, 5]}, sharex=True)\n", + "fig, axs = plt.subplots(\n", + " 2, 1, figsize=(12, 14), gridspec_kw={\"height_ratios\": [1, 5]}, sharex=True\n", + ")\n", "plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", "set_light_theme()\n", "\n", - "df_test_unique = test_df.drop_duplicates(subset=\"group_id\", keep='first')\n", + "df_test_unique = test_df.drop_duplicates(subset=\"group_id\", keep=\"first\")\n", "\n", "\n", - "#sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "sns.kdeplot(ax=axs[0], data=df_test_unique, x=\"RETENTIONTIME\", bw_adjust=0.25, color=\"gray\", fill=True)#, multiple=\"stack\") #hue=\"Precursor_type\", \n", - "sns.kdeplot(ax=axs[0], data=df_test_unique, x=\"RETENTIONTIME\", bw_adjust=0.25, color=\"gray\") #, multiple=\"stack\") #hue=\"Precursor_type\", \n", + "# sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "sns.kdeplot(\n", + " ax=axs[0],\n", + " data=df_test_unique,\n", + " x=\"RETENTIONTIME\",\n", + " bw_adjust=0.25,\n", + " color=\"gray\",\n", + " fill=True,\n", + ") # , multiple=\"stack\") #hue=\"Precursor_type\",\n", + "sns.kdeplot(\n", + " ax=axs[0], data=df_test_unique, x=\"RETENTIONTIME\", bw_adjust=0.25, color=\"gray\"\n", + ") # , multiple=\"stack\") #hue=\"Precursor_type\",\n", "\n", - "axs[0].spines['top'].set_visible(False)\n", - "axs[0].spines['right'].set_visible(False)\n", + "axs[0].spines[\"top\"].set_visible(False)\n", + "axs[0].spines[\"right\"].set_visible(False)\n", "\n", "\n", - "sns.scatterplot(ax=axs[1], data=df_test_unique, x=\"RETENTIONTIME\", y=\"RT_pred\", color=\"gray\")#, hue=\"library\", palette=tri_palette, style=\"library\", color=\"gray\")\n", - "axs[1].set_ylim([0,df_test_unique[\"RETENTIONTIME\"].max() + 1 ])\n", - "axs[1].set_xlim([0,df_test_unique[\"RETENTIONTIME\"].max() + 1])\n", + "sns.scatterplot(\n", + " ax=axs[1], data=df_test_unique, x=\"RETENTIONTIME\", y=\"RT_pred\", 
color=\"gray\"\n", + ") # , hue=\"library\", palette=tri_palette, style=\"library\", color=\"gray\")\n", + "axs[1].set_ylim([0, df_test_unique[\"RETENTIONTIME\"].max() + 1])\n", + "axs[1].set_xlim([0, df_test_unique[\"RETENTIONTIME\"].max() + 1])\n", "axs[1].set_ylabel(\"Predicted retention time\")\n", "axs[1].set_xlabel(\"Observed retention time\")\n", "line = [0, 100]\n", "sns.lineplot(ax=axs[1], x=line, y=line, color=\"black\")\n", - "sns.lineplot(ax=axs[1], x=line, y=[x + 20/60.0 for x in line], color=\"black\", linestyle='--')\n", - "sns.lineplot(ax=axs[1], x=line, y=[x - 20/60.0 for x in line], color=\"black\", linestyle='--')\n", + "sns.lineplot(\n", + " ax=axs[1], x=line, y=[x + 20 / 60.0 for x in line], color=\"black\", linestyle=\"--\"\n", + ")\n", + "sns.lineplot(\n", + " ax=axs[1], x=line, y=[x - 20 / 60.0 for x in line], color=\"black\", linestyle=\"--\"\n", + ")\n", "plt.show()" ] }, @@ -3464,28 +3820,62 @@ ], "source": [ "# TODO NEXT UP!!\n", - "fig, axs = plt.subplots(2, 1, figsize=(12, 14), gridspec_kw={'height_ratios': [1, 5]}, sharex=True)\n", + "fig, axs = plt.subplots(\n", + " 2, 1, figsize=(12, 14), gridspec_kw={\"height_ratios\": [1, 5]}, sharex=True\n", + ")\n", "plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", "\n", - "df_test_unique = test_df.drop_duplicates(subset='group_id', keep='first')\n", - "CCS = pd.concat([df_test_unique[[\"CCS\", \"CCS_pred\", \"Dataset\"]], casmi16_df[[\"CCS\", \"CCS_pred\", \"Dataset\"]]], ignore_index=True) #\n", - "\n", - "\n", - "#sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "sns.kdeplot(ax=axs[0], data=CCS, x=\"CCS\", bw_adjust=0.35, color=\"black\", multiple=\"stack\", hue=\"Dataset\", palette=tri_palette, edgecolor=\"white\") #hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "axs[0].spines['top'].set_visible(False)\n", - "axs[0].spines['right'].set_visible(False)\n", - "\n", - "\n", - "sns.scatterplot(ax=axs[1], data=CCS, x=\"CCS\", y=\"CCS_pred\", hue=\"Dataset\", palette=tri_palette, style=\"Dataset\", markers=[\".\", \"X\", \"*\"], color=\"gray\", s=25, linewidth=.0)#, color=\"blue\", edgecolor=\"blue\")#, \n", - "axs[1].set_ylim([df_test_unique[\"CCS\"].min() - 10,df_test_unique[\"CCS\"].max() + 10])\n", - "axs[1].set_xlim([df_test_unique[\"CCS\"].min() - 10,df_test_unique[\"CCS\"].max() + 10])\n", + "df_test_unique = test_df.drop_duplicates(subset=\"group_id\", keep=\"first\")\n", + "CCS = pd.concat(\n", + " [\n", + " df_test_unique[[\"CCS\", \"CCS_pred\", \"Dataset\"]],\n", + " casmi16_df[[\"CCS\", \"CCS_pred\", \"Dataset\"]],\n", + " ],\n", + " ignore_index=True,\n", + ") #\n", + "\n", + "\n", + "# sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "sns.kdeplot(\n", + " ax=axs[0],\n", + " data=CCS,\n", + " x=\"CCS\",\n", + " bw_adjust=0.35,\n", + " color=\"black\",\n", + " multiple=\"stack\",\n", + " hue=\"Dataset\",\n", + " palette=tri_palette,\n", + " edgecolor=\"white\",\n", + ") # hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "axs[0].spines[\"top\"].set_visible(False)\n", + "axs[0].spines[\"right\"].set_visible(False)\n", + "\n", + "\n", + "sns.scatterplot(\n", + " ax=axs[1],\n", + " data=CCS,\n", + " x=\"CCS\",\n", 
+ " y=\"CCS_pred\",\n", + " hue=\"Dataset\",\n", + " palette=tri_palette,\n", + " style=\"Dataset\",\n", + " markers=[\".\", \"X\", \"*\"],\n", + " color=\"gray\",\n", + " s=25,\n", + " linewidth=0.0,\n", + ") # , color=\"blue\", edgecolor=\"blue\")#,\n", + "axs[1].set_ylim([df_test_unique[\"CCS\"].min() - 10, df_test_unique[\"CCS\"].max() + 10])\n", + "axs[1].set_xlim([df_test_unique[\"CCS\"].min() - 10, df_test_unique[\"CCS\"].max() + 10])\n", "axs[1].set_ylabel(\"Predicted CCS\")\n", "axs[1].set_xlabel(\"Observed CCS\")\n", - "line=[df_test_unique[\"CCS\"].min() - 10,df_test_unique[\"CCS\"].max() + 10]\n", + "line = [df_test_unique[\"CCS\"].min() - 10, df_test_unique[\"CCS\"].max() + 10]\n", "sns.lineplot(ax=axs[1], x=line, y=line, color=\"black\")\n", - "sns.lineplot(ax=axs[1], x=line, y=[1.1*x for x in line], color=\"black\", linestyle='--')\n", - "sns.lineplot(ax=axs[1], x=line, y=[0.9*x for x in line], color=\"black\", linestyle='--')\n", + "sns.lineplot(\n", + " ax=axs[1], x=line, y=[1.1 * x for x in line], color=\"black\", linestyle=\"--\"\n", + ")\n", + "sns.lineplot(\n", + " ax=axs[1], x=line, y=[0.9 * x for x in line], color=\"black\", linestyle=\"--\"\n", + ")\n", "plt.show()" ] }, @@ -3520,24 +3910,79 @@ } ], "source": [ - "\n", "print(\"TEST SPLIT:\\n\")\n", "print(\"Pearson Corr Coef:\")\n", - "print(\"GNN\", np.corrcoef(df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"], df_test_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(), dtype=float)[0,1])\n", - "print(\"LR \", np.corrcoef(df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"], df_test_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna(), dtype=float)[0,1])\n", - "\n", - "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(df_train.dropna(subset=[\"CCS\"])[\"PRECURSORMZ\"], df_train.dropna(subset=[\"CCS\"])[\"CCS\"])\n", + "print(\n", + " \"GNN\",\n", + " np.corrcoef(\n", + " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " dtype=float,\n", + " )[0, 1],\n", + ")\n", + "print(\n", + " \"LR \",\n", + " np.corrcoef(\n", + " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_test_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna(),\n", + " dtype=float,\n", + " )[0, 1],\n", + ")\n", + "\n", + "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(\n", + " df_train.dropna(subset=[\"CCS\"])[\"PRECURSORMZ\"],\n", + " df_train.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + ")\n", "print(\"R2\")\n", - "print(\"GNN\", r2_score(df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"], df_test_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna()))\n", - "print(\"LR \", r2_score(df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"], intercept + slope *df_test_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna()))\n", + "print(\n", + " \"GNN\",\n", + " r2_score(\n", + " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " ),\n", + ")\n", + "print(\n", + " \"LR \",\n", + " r2_score(\n", + " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " intercept\n", + " + slope * df_test_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna(),\n", + " ),\n", + ")\n", "\n", "print(\"---------------\\n\\nCASMI-16:\\n\")\n", "print(\"Pearson Corr Coef:\")\n", - "print(\"GNN\", np.corrcoef(casmi16_df.dropna(subset=[\"CCS\"])[\"CCS\"], casmi16_df.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(), dtype=float)[0,1])\n", - "print(\"LR \", 
np.corrcoef(casmi16_df.dropna(subset=[\"CCS\"])[\"CCS\"], casmi16_df.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna(), dtype=float)[0,1])\n", + "print(\n", + " \"GNN\",\n", + " np.corrcoef(\n", + " casmi16_df.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " casmi16_df.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " dtype=float,\n", + " )[0, 1],\n", + ")\n", + "print(\n", + " \"LR \",\n", + " np.corrcoef(\n", + " casmi16_df.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " casmi16_df.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna(),\n", + " dtype=float,\n", + " )[0, 1],\n", + ")\n", "print(\"R2\")\n", - "print(\"GNN\", r2_score(casmi16_df.dropna(subset=[\"CCS\"])[\"CCS\"], casmi16_df.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna()))\n", - "print(\"LR \", r2_score(casmi16_df.dropna(subset=[\"CCS\"])[\"CCS\"], intercept + slope *casmi16_df.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna()))" + "print(\n", + " \"GNN\",\n", + " r2_score(\n", + " casmi16_df.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " casmi16_df.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " ),\n", + ")\n", + "print(\n", + " \"LR \",\n", + " r2_score(\n", + " casmi16_df.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " intercept + slope * casmi16_df.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna(),\n", + " ),\n", + ")" ] } ], diff --git a/notebooks/test_model.ipynb b/notebooks/test_model.ipynb index 002ec4e..d9c570a 100644 --- a/notebooks/test_model.ipynb +++ b/notebooks/test_model.ipynb @@ -32,7 +32,7 @@ "\n", "\n", "seed = 42\n", - "#torch.set_default_dtype(torch.float64)\n", + "# torch.set_default_dtype(torch.float64)\n", "torch.manual_seed(seed)\n", "torch.set_printoptions(precision=2, sci_mode=False)\n", "\n", @@ -41,12 +41,13 @@ "import numpy as np\n", "import ast\n", "import copy\n", - "import matplotlib.pyplot as plt \n", + "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "# Load Modules\n", "sys.path.append(\"..\")\n", "from os.path import expanduser\n", + "\n", "home = expanduser(\"~\")\n", "from fiora.MOL.constants import *\n", "from fiora.IO.LibraryLoader import LibraryLoader\n", @@ -56,9 +57,10 @@ "from sklearn.metrics import r2_score\n", "import scipy\n", "from rdkit import RDLogger\n", - "RDLogger.DisableLog('rdApp.*')\n", "\n", - "print(f'Working with Python {sys.version}')\n" + "RDLogger.DisableLog(\"rdApp.*\")\n", + "\n", + "print(f\"Working with Python {sys.version}\")" ] }, { @@ -80,13 +82,13 @@ "# MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/pre_package/v0.0.1_2_OS_depth{depth}_June24+CCS+RT.pt\" # OS model (first try)\n", "\n", "# NEW AND SHINY\n", - "MODEL_PATH = f\"../models/fiora_OS_v1.0.0.pt\" # Release version\n", + "MODEL_PATH = f\"../models/fiora_OS_v1.0.0.pt\" # Release version\n", "# MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v1.0.0_OS_depth10_Sep25_x4.pt\"\n", "# MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v0.0.1_merged_depth{depth}_Aug24_sqrt+CCS+RT_drop3.pt\" # New sqrt model (improved) | Note: drop3 uses dropout reduction while training RT, CCS\n", "# MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v0.0.1_OS_depth{depth}_Aug24_sqrt_4.pt\" # or Aug24_sqrt_x are new OS models\n", "\n", - "#v: str = \"0.1.0\"\n", - "#MODEL_PATH = f\"../models/fiora_OS_v{v}.pt\" # Release version\n", + "# v: str = \"0.1.0\"\n", + "# MODEL_PATH = f\"../models/fiora_OS_v{v}.pt\" # Release version\n", "\n", "from fiora.GNN.FioraModel import FioraModel\n", "from fiora.MS.SimulationFramework import SimulationFramework\n", @@ -95,8 +97,7 @@ 
"try:\n", " model = FioraModel.load_from_state_dict(MODEL_PATH)\n", "except:\n", - " raise NameError(\"Error: Failed loading from state dict.\")\n", - " " + " raise NameError(\"Error: Failed loading from state dict.\")" ] }, { @@ -137,7 +138,7 @@ "# key map to read metadata from pandas DataFrame\n", "# metadata_key_map = {\n", "# \"name\": \"Name\",\n", - "# \"collision_energy\": \"CE\", \n", + "# \"collision_energy\": \"CE\",\n", "# \"instrument\": \"Instrument_type\",\n", "# \"ionization\": \"Ionization\",\n", "# \"precursor_mz\": \"PrecursorMZ\",\n", @@ -151,14 +152,18 @@ "# Load specified libraries and align metadata\n", "#\n", "\n", + "\n", "def load_training_data():\n", " L = LibraryLoader()\n", " df = L.load_from_csv(f\"{home}/data/metabolites/preprocessed/datasplits_Jan24.csv\")\n", " return df\n", "\n", + "\n", "def load_msnlib():\n", " L = LibraryLoader()\n", - " df = L.load_from_csv(f\"{home}/data/metabolites/preprocessed/datasplits_msnlib_Aug24_v3.csv\")\n", + " df = L.load_from_csv(\n", + " f\"{home}/data/metabolites/preprocessed/datasplits_msnlib_Aug24_v3.csv\"\n", + " )\n", " return df\n", "\n", "\n", @@ -168,11 +173,13 @@ "# Restore dictionary values\n", "dict_columns = [\"peaks\", \"summary\"]\n", "for col in dict_columns:\n", - " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", - " df_msnlib[col] = df_msnlib[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", - " #df[col] = df[col].apply(ast.literal_eval)\n", - " \n", - "df['group_id'] = df['group_id'].astype(int)\n" + " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace(\"nan\", \"None\")))\n", + " df_msnlib[col] = df_msnlib[col].apply(\n", + " lambda x: ast.literal_eval(x.replace(\"nan\", \"None\"))\n", + " )\n", + " # df[col] = df[col].apply(ast.literal_eval)\n", + "\n", + "df[\"group_id\"] = df[\"group_id\"].astype(int)" ] }, { @@ -199,7 +206,7 @@ "source": [ "print(df.groupby(\"lib\")[\"group_id\"].unique().apply(len))\n", "print(df[\"lib\"].value_counts())\n", - "print(len(df[\"group_id\"].unique()))\n" + "print(len(df[\"group_id\"].unique()))" ] }, { @@ -228,7 +235,7 @@ "df_test = df[df[\"dataset\"] == \"test\"]\n", "\n", "df_msnlib_train = df_msnlib[df_msnlib[\"dataset\"] != \"test\"]\n", - "df_msnlib_test = df_msnlib[df_msnlib[\"dataset\"] == \"test\"]\n" + "df_msnlib_test = df_msnlib[df_msnlib[\"dataset\"] == \"test\"]" ] }, { @@ -237,7 +244,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "from fiora.MOL.Metabolite import Metabolite\n", "from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", "from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder\n", @@ -251,35 +257,66 @@ "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", "model_setup_feature_sets = None\n", "if \"setup_features_categorical_set\" in model.model_params.keys():\n", - " model_setup_feature_sets = model.model_params[\"setup_features_categorical_set\"] \n", + " model_setup_feature_sets = model.model_params[\"setup_features_categorical_set\"]\n", " # TODO Refactor this:\n", " for i, data in df_test.iterrows():\n", - " if df_test.loc[i][\"summary\"][\"instrument\"] not in model_setup_feature_sets[\"instrument\"]:\n", + " if (\n", + " df_test.loc[i][\"summary\"][\"instrument\"]\n", + " not in model_setup_feature_sets[\"instrument\"]\n", + " ):\n", " df_test.loc[i][\"summary\"][\"instrument\"] = \"HCD\"\n", - "covariate_encoder = CovariateFeatureEncoder(feature_list=[\"collision_energy\", \"molecular_weight\", \"precursor_mode\", 
\"instrument\", \"element_composition\"], sets_overwrite=model_setup_feature_sets)\n", - "rt_encoder = CovariateFeatureEncoder(feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"], sets_overwrite=model_setup_feature_sets)\n", + "covariate_encoder = CovariateFeatureEncoder(\n", + " feature_list=[\n", + " \"collision_energy\",\n", + " \"molecular_weight\",\n", + " \"precursor_mode\",\n", + " \"instrument\",\n", + " \"element_composition\",\n", + " ],\n", + " sets_overwrite=model_setup_feature_sets,\n", + ")\n", + "rt_encoder = CovariateFeatureEncoder(\n", + " feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"],\n", + " sets_overwrite=model_setup_feature_sets,\n", + ")\n", "\n", "\n", "def process_dataframes(df_train, df_test):\n", "\n", - " df_train[\"Metabolite\"] = df_train[\"SMILES\"].apply(Metabolite) # TRAIN Metabolites are only tracked for tanimoto distance\n", + " df_train[\"Metabolite\"] = df_train[\"SMILES\"].apply(\n", + " Metabolite\n", + " ) # TRAIN Metabolites are only tracked for tanimoto distance\n", " df_test[\"Metabolite\"] = df_test[\"SMILES\"].apply(Metabolite)\n", " df_test[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", + " covariate_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", + " covariate_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", + " rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", "\n", - " covariate_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit \n", - " covariate_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n", - " rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n", - "\n", - " df_test[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - " df_test.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]) , axis=1)\n", - "\n", - " #df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", - " df_test.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder), axis=1)\n", - " df_train.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], process_metadata=False), axis=1)\n", + " df_test[\"Metabolite\"].apply(\n", + " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", + " )\n", + " df_test.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]), axis=1)\n", + "\n", + " # df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", + " df_test.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(\n", + " x[\"summary\"], covariate_encoder, rt_encoder\n", + " ),\n", + " axis=1,\n", + " )\n", + " df_train.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], process_metadata=False),\n", + " axis=1,\n", + " )\n", "\n", " df_test[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - " df_test.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]), axis=1)\n", + " df_test.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]\n", + " ),\n", + " axis=1,\n", + " )\n", "\n", " return df_train, df_test" ] @@ -330,7 +367,11 @@ "casmi22_path = 
f\"{home}/data/metabolites/CASMI_2022/casmi22_withCCS.csv\"\n", "\n", "df_cas = pd.read_csv(casmi16_path, index_col=[0], low_memory=False)\n", - "df_cast = pd.read_csv(f\"{home}/data/metabolites/CASMI_2016/casmi16t_withCCS.csv\", index_col=[0], low_memory=False) # f\"{home}/data/metabolites/CASMI_2016/casmi16_training_combined.csv\"\n", + "df_cast = pd.read_csv(\n", + " f\"{home}/data/metabolites/CASMI_2016/casmi16t_withCCS.csv\",\n", + " index_col=[0],\n", + " low_memory=False,\n", + ") # f\"{home}/data/metabolites/CASMI_2016/casmi16_training_combined.csv\"\n", "df_cas22 = pd.read_csv(casmi22_path, index_col=[0], low_memory=False)\n", "\n", "# Restore dictionary values\n", @@ -342,12 +383,12 @@ "df_cas[\"is_priority\"] = True\n", "df_cast[\"is_priority\"] = False\n", "df_cas22[\"peaks\"] = df_cas22[\"peaks\"].apply(ast.literal_eval)\n", - "df_cas22[\"ChallengeNum\"] = df_cas22[\"ChallengeName\"].apply(lambda x: int(x.split(\"-\")[-1]))\n", + "df_cas22[\"ChallengeNum\"] = df_cas22[\"ChallengeName\"].apply(\n", + " lambda x: int(x.split(\"-\")[-1])\n", + ")\n", "df_cas22[\"is_priority\"] = (df_cas22[\"ChallengeNum\"] < 250).astype(bool)\n", "\n", "\n", - "\n", - "\n", "def closest_cfm_ce(CE):\n", " ref = np.array([10, 20, 40])\n", " abs_error = abs(ref - CE)\n", @@ -429,23 +470,35 @@ "df_cas[\"Metabolite\"] = df_cas[\"SMILES\"].apply(Metabolite)\n", "df_cas[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cas[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df_cas[\"CE\"] = 20.0 # actually stepped 20/35/50\n", - "df_cas[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", - "\n", - "metadata_key_map16 = {\"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"PRECURSOR_MZ\",\n", - " 'precursor_mode': \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\"\n", - " }\n", - "\n", - "df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n", - "df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder), axis=1)\n", + "df_cas[\"Metabolite\"].apply(\n", + " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", + ")\n", + "df_cas[\"CE\"] = 20.0 # actually stepped 20/35/50\n", + "df_cas[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", + "\n", + "metadata_key_map16 = {\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"precursor_mz\": \"PRECURSOR_MZ\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"RETENTIONTIME\",\n", + "}\n", + "\n", + "df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", + ")\n", + "df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder), axis=1\n", + ")\n", "\n", "# Fragmentation\n", "df_cas[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "df_cas.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=300 * PPM), axis=1)\n", + "df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=300 * PPM\n", + " ),\n", + " axis=1,\n", + ")\n", "\n", "#\n", "# CASMI 22\n", @@ -455,30 +508,45 @@ "df_cas22[\"Metabolite\"] = df_cas22[\"SMILES\"].apply(Metabolite)\n", 
"df_cas22[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cas22[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df_cas22[\"CE\"] = df_cas22.apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"precursor_mz\"]), axis=1)\n", - "df_cas22 = df_cas22[df_cas22[\"CE\"] < CE_upper_limit] \n", - "df_cas22 = df_cas22[df_cas22[\"CE\"] > 0] \n", - "#df_cas22 = df_cas22[df_cas22.is_priority]\n", - "\n", - "metadata_key_map22 = {\"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"precursor_mz\",\n", - " 'precursor_mode': \"Precursor_type\",\n", - " \"retention_time\": \"ChallengeRT\"\n", - " }\n", - "\n", - "df_cas22[\"summary\"] = df_cas22.apply(lambda x: {key: x[name] for key, name in metadata_key_map22.items()}, axis=1)\n", - "df_cas22.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder), axis=1)\n", + "df_cas22[\"Metabolite\"].apply(\n", + " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", + ")\n", + "df_cas22[\"CE\"] = df_cas22.apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"precursor_mz\"]), axis=1\n", + ")\n", + "df_cas22 = df_cas22[df_cas22[\"CE\"] < CE_upper_limit]\n", + "df_cas22 = df_cas22[df_cas22[\"CE\"] > 0]\n", + "# df_cas22 = df_cas22[df_cas22.is_priority]\n", + "\n", + "metadata_key_map22 = {\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"precursor_mz\": \"precursor_mz\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"ChallengeRT\",\n", + "}\n", + "\n", + "df_cas22[\"summary\"] = df_cas22.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map22.items()}, axis=1\n", + ")\n", + "df_cas22.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder),\n", + " axis=1,\n", + ")\n", "\n", "# Fragmentation\n", "df_cas22[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "df_cas22.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=300 * PPM), axis=1) # Optional: use mz_cut instead\n", + "df_cas22.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=300 * PPM\n", + " ),\n", + " axis=1,\n", + ") # Optional: use mz_cut instead\n", "\n", "df_cas22 = df_cas22.reset_index()\n", "\n", "df_cas22[\"library\"] = \"CASMI-22\"\n", - "df_cas22[\"RETENTIONTIME\"] = df_cas22[\"ChallengeRT\"] # \"RT_min\"\n", + "df_cas22[\"RETENTIONTIME\"] = df_cas22[\"ChallengeRT\"] # \"RT_min\"\n", "df_cas22[\"cfm_CE\"] = df_cas22[\"CE\"].apply(closest_cfm_ce)\n", "df_cas22[[\"NCE\", \"CE\", \"cfm_CE\"]].head(3)" ] @@ -496,23 +564,35 @@ "df_cast[\"Metabolite\"] = df_cast[\"SMILES\"].apply(Metabolite)\n", "df_cast[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cast[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df_cast[\"CE\"] = 20.0 # actually stepped 20/35/50\n", - "df_cast[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", - "\n", - "metadata_key_map16 = {\"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"PRECURSOR_MZ\",\n", - " 'precursor_mode': \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\"\n", - " }\n", - "\n", - "df_cast[\"summary\"] = df_cast.apply(lambda x: {key: 
x[name] for key, name in metadata_key_map16.items()}, axis=1)\n", - "df_cast.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder), axis=1)\n", + "df_cast[\"Metabolite\"].apply(\n", + " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", + ")\n", + "df_cast[\"CE\"] = 20.0 # actually stepped 20/35/50\n", + "df_cast[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", + "\n", + "metadata_key_map16 = {\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"precursor_mz\": \"PRECURSOR_MZ\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"RETENTIONTIME\",\n", + "}\n", + "\n", + "df_cast[\"summary\"] = df_cast.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", + ")\n", + "df_cast.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder), axis=1\n", + ")\n", "\n", "# Fragmentation\n", "df_cast[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "_ = df_cast.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=300 * PPM), axis=1)\n" + "_ = df_cast.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=300 * PPM\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -540,14 +620,12 @@ "from fiora.GNN.Trainer import Trainer\n", "import torch_geometric as geom\n", "\n", - "if torch.cuda.is_available(): \n", - " dev = \"cuda:0\"\n", - "else: \n", - " dev = \"cpu\" \n", + "if torch.cuda.is_available():\n", + " dev = \"cuda:0\"\n", + "else:\n", + " dev = \"cpu\"\n", "\n", - "print(f\"Running on device: {dev}\")\n", - "\n", - "\n" + "print(f\"Running on device: {dev}\")" ] }, { @@ -615,8 +693,6 @@ } ], "source": [ - "\n", - "\n", "model.eval()\n", "model = model.to(dev)\n", "\n", @@ -632,11 +708,13 @@ "metadata": {}, "outputs": [], "source": [ - "np.seterr(invalid='ignore')\n", + "np.seterr(invalid=\"ignore\")\n", + "\n", + "\n", "def simulate_all(model, DF):\n", " return fiora.simulate_all(DF, model)\n", "\n", - " \n", + "\n", "def test_model(model, DF):\n", " dft = simulate_all(model, DF)\n", " return dft" @@ -656,53 +734,126 @@ "outputs": [], "source": [ "from fiora.MOL.collision_energy import NCE_to_eV\n", - "from fiora.MS.spectral_scores import spectral_cosine, spectral_reflection_cosine, reweighted_dot\n", + "from fiora.MS.spectral_scores import (\n", + " spectral_cosine,\n", + " spectral_reflection_cosine,\n", + " reweighted_dot,\n", + ")\n", "from fiora.MS.ms_utility import merge_annotated_spectrum\n", "\n", "\n", "def test_cas16(model, df_cas=df_cas, ignore_RT=False):\n", - " \n", - " \n", + "\n", " # Predict spectra for first NCE step\n", " df_cas[\"NCE\"] = 20.0\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1)\n", + " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " )\n", " df_cas[\"step1_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n", - " df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder), axis=1)\n", + " df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", + " 
)\n", + " df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(\n", + " x[\"summary\"], covariate_encoder, rt_encoder\n", + " ),\n", + " axis=1,\n", + " )\n", " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_20\")\n", "\n", " # Predict spectra for last NCE step\n", - " df_cas[\"NCE\"] = 50.0 \n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1)\n", + " df_cas[\"NCE\"] = 50.0\n", + " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " )\n", " df_cas[\"step3_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n", - " df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder), axis=1)\n", + " df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", + " )\n", + " df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(\n", + " x[\"summary\"], covariate_encoder, rt_encoder\n", + " ),\n", + " axis=1,\n", + " )\n", " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_50\")\n", "\n", " # Predict spectra for middle (average) NCE step (doing this last makes sure Metabolite metadata references the average case)\n", " df_cas[\"NCE\"] = 35.0\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1)\n", + " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", + " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " )\n", " df_cas[\"step2_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n", - " df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder), axis=1)\n", + " df_cas[\"summary\"] = df_cas.apply(\n", + " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", + " )\n", + " df_cas.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(\n", + " x[\"summary\"], covariate_encoder, rt_encoder\n", + " ),\n", + " axis=1,\n", + " )\n", " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_35\")\n", "\n", - " df_cas[\"avg_CE\"] = (df_cas[\"step1_CE\"] + df_cas[\"step2_CE\"] + df_cas[\"step3_CE\"]) / 3\n", + " df_cas[\"avg_CE\"] = (\n", + " df_cas[\"step1_CE\"] + df_cas[\"step2_CE\"] + df_cas[\"step3_CE\"]\n", + " ) / 3\n", " df_cas[\"CE\"] = df_cas[\"avg_CE\"]\n", - " \n", - " df_cas[\"merged_peaks\"] = df_cas.apply(lambda x: merge_annotated_spectrum(merge_annotated_spectrum(x[\"sim_peaks_20\"], x[\"sim_peaks_35\"]), x[\"sim_peaks_50\"]) , axis=1)\n", - " df_cas[\"sim_peaks\"] = df_cas[\"merged_peaks\"] \n", - " df_cas[\"merged_cosine\"] = df_cas.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"]), axis=1)\n", - " df_cas[\"merged_sqrt_cosine\"] = df_cas.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt), axis=1)\n", - " df_cas[\"merged_sqrt_bias\"] = df_cas.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt, with_bias=True)[1], axis=1)\n", - " df_cas[\"merged_sqrt_cosine_wo_precursor\"] = df_cas.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt, remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(x[\"Metabolite\"].metadata[\"precursor_mode\"])), 
axis=1)\n", - " df_cas[\"merged_refl_cosine\"] = df_cas.apply(lambda x: spectral_reflection_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt), axis=1)\n", - " df_cas[\"merged_steins\"] = df_cas.apply(lambda x: reweighted_dot(x[\"peaks\"], x[\"merged_peaks\"]), axis=1)\n", - " df_cas[\"spectral_cosine\"] = df_cas[\"merged_cosine\"] # just remember it is merged\n", - " df_cas[\"spectral_sqrt_cosine\"] = df_cas[\"merged_sqrt_cosine\"] # just remember it is merged\n", - " df_cas[\"spectral_sqrt_cosine_wo_prec\"] = df_cas[\"merged_sqrt_cosine_wo_precursor\"] # just remember it is merged\n", - " df_cas[\"spectral_sqrt_cosine_avg\"] = (df_cas[\"spectral_sqrt_cosine\"] + df_cas[\"spectral_sqrt_cosine_wo_prec\"]) / 2.0\n", - " df_cas[\"spectral_sqrt_bias\"] = df_cas[\"merged_sqrt_bias\"] # just remember it is merged\n", + "\n", + " df_cas[\"merged_peaks\"] = df_cas.apply(\n", + " lambda x: merge_annotated_spectrum(\n", + " merge_annotated_spectrum(x[\"sim_peaks_20\"], x[\"sim_peaks_35\"]),\n", + " x[\"sim_peaks_50\"],\n", + " ),\n", + " axis=1,\n", + " )\n", + " df_cas[\"sim_peaks\"] = df_cas[\"merged_peaks\"]\n", + " df_cas[\"merged_cosine\"] = df_cas.apply(\n", + " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " )\n", + " df_cas[\"merged_sqrt_cosine\"] = df_cas.apply(\n", + " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_sqrt_bias\"] = df_cas.apply(\n", + " lambda x: spectral_cosine(\n", + " x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt, with_bias=True\n", + " )[1],\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_sqrt_cosine_wo_precursor\"] = df_cas.apply(\n", + " lambda x: spectral_cosine(\n", + " x[\"peaks\"],\n", + " x[\"merged_peaks\"],\n", + " transform=np.sqrt,\n", + " remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(\n", + " x[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " ),\n", + " ),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_refl_cosine\"] = df_cas.apply(\n", + " lambda x: spectral_reflection_cosine(\n", + " x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt\n", + " ),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_steins\"] = df_cas.apply(\n", + " lambda x: reweighted_dot(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " )\n", + " df_cas[\"spectral_cosine\"] = df_cas[\"merged_cosine\"] # just remember it is merged\n", + " df_cas[\"spectral_sqrt_cosine\"] = df_cas[\n", + " \"merged_sqrt_cosine\"\n", + " ] # just remember it is merged\n", + " df_cas[\"spectral_sqrt_cosine_wo_prec\"] = df_cas[\n", + " \"merged_sqrt_cosine_wo_precursor\"\n", + " ] # just remember it is merged\n", + " df_cas[\"spectral_sqrt_cosine_avg\"] = (\n", + " df_cas[\"spectral_sqrt_cosine\"] + df_cas[\"spectral_sqrt_cosine_wo_prec\"]\n", + " ) / 2.0\n", + " df_cas[\"spectral_sqrt_bias\"] = df_cas[\n", + " \"merged_sqrt_bias\"\n", + " ] # just remember it is merged\n", "\n", " df_cas[\"coverage\"] = df_cas[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", " if not ignore_RT:\n", @@ -712,9 +863,8 @@ "\n", " df_cas[\"library\"] = \"CASMI-16\"\n", "\n", - " \n", " df_cas[\"cfm_CE\"] = df_cas[\"avg_CE\"].apply(closest_cfm_ce)\n", - " \n", + "\n", " return df_cas" ] }, @@ -15622,7 +15772,7 @@ ], "source": [ "print(f\"Testing the model\")\n", - "np.seterr(invalid='ignore')\n", + "np.seterr(invalid=\"ignore\")\n", "df_test = test_model(model, df_test)\n", "df_msnlib_test = test_model(model, df_msnlib_test)\n", "df_cas = test_cas16(model, 
ignore_RT=True)\n", @@ -15704,10 +15854,16 @@ "source": [ "import fiora.IO.cfmReader as cfmReader\n", "\n", - "cf = cfmReader.read(f\"{home}/data/metabolites/cfm-id/msnlib_test_split_negative_predictions.txt\", as_df=True)\n", - "cf_p = cfmReader.read(f\"{home}/data/metabolites/cfm-id/msnlib_test_split_positive_predictions.txt\", as_df=True)\n", - "cf[\"ion_type\"] = \"negative\" \n", - "cf_p[\"ion_type\"] = \"positive\" \n", + "cf = cfmReader.read(\n", + " f\"{home}/data/metabolites/cfm-id/msnlib_test_split_negative_predictions.txt\",\n", + " as_df=True,\n", + ")\n", + "cf_p = cfmReader.read(\n", + " f\"{home}/data/metabolites/cfm-id/msnlib_test_split_positive_predictions.txt\",\n", + " as_df=True,\n", + ")\n", + "cf[\"ion_type\"] = \"negative\"\n", + "cf_p[\"ion_type\"] = \"positive\"\n", "cf = pd.concat([cf, cf_p])\n", "cf[\"#ID\"] = cf[\"#ID\"].astype(int)\n", "df_msnlib_test[\"cfm_CE\"] = df_msnlib_test[\"CE\"].apply(closest_cfm_ce)\n", @@ -15720,16 +15876,28 @@ " if len(cf[(cf[\"#ID\"] == group_id) & (cf[\"Precursor_type\"] == precursor_type)]) < 1:\n", " print(f\"{group_id} not found in CFM-ID results. Skipping.\")\n", " continue\n", - " \n", - " cfm_data = cf[(cf[\"#ID\"] == group_id) & (cf[\"Precursor_type\"] == precursor_type)].iloc[0]\n", - " \n", - " \n", - " cfm_peaks = cfm_data[\"peaks\" + data[\"cfm_CE\"]] # find best reference CE\n", + "\n", + " cfm_data = cf[\n", + " (cf[\"#ID\"] == group_id) & (cf[\"Precursor_type\"] == precursor_type)\n", + " ].iloc[0]\n", + "\n", + " cfm_peaks = cfm_data[\"peaks\" + data[\"cfm_CE\"]] # find best reference CE\n", " df_msnlib_test.at[i, \"cfm_peaks\"] = cfm_peaks\n", " df_msnlib_test.at[i, \"cfm_cosine\"] = spectral_cosine(data[\"peaks\"], cfm_peaks)\n", - " df_msnlib_test.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(data[\"peaks\"], cfm_peaks, transform=np.sqrt)\n", - " df_msnlib_test.at[i, \"cfm_sqrt_cosine_wo_prec\"] = spectral_cosine(data[\"peaks\"], cfm_peaks, transform=np.sqrt, remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]))\n", - " df_msnlib_test.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(data[\"peaks\"], cfm_peaks, transform=np.sqrt)\n", + " df_msnlib_test.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(\n", + " data[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " )\n", + " df_msnlib_test.at[i, \"cfm_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", + " data[\"peaks\"],\n", + " cfm_peaks,\n", + " transform=np.sqrt,\n", + " remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(\n", + " ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " ),\n", + " )\n", + " df_msnlib_test.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(\n", + " data[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " )\n", " df_msnlib_test.at[i, \"cfm_steins\"] = reweighted_dot(data[\"peaks\"], cfm_peaks)" ] }, @@ -15753,10 +15921,14 @@ "import fiora.IO.cfmReader as cfmReader\n", "# time CFM-ID 4: -> 12m16,571s\n", "\n", - "cf = cfmReader.read(f\"{home}/data/metabolites/cfm-id/casmi16_negative_predictions.txt\", as_df=True)\n", - "cf_p = cfmReader.read(f\"{home}/data/metabolites/cfm-id/casmi16_positive_predictions.txt\", as_df=True)\n", + "cf = cfmReader.read(\n", + " f\"{home}/data/metabolites/cfm-id/casmi16_negative_predictions.txt\", as_df=True\n", + ")\n", + "cf_p = cfmReader.read(\n", + " f\"{home}/data/metabolites/cfm-id/casmi16_positive_predictions.txt\", as_df=True\n", + ")\n", "cf = pd.concat([cf, cf_p])\n", - "len(cf[cf[\"#ID\"] == \"Challenge-009\"]) 
## missing chalenges\n",
+ "len(cf[cf[\"#ID\"] == \"Challenge-009\"]) ## missing challenges\n",
 "df_cas[\"cfm_peaks\"] = None\n",
 "df_cas[[\"cfm_cosine\", \"cfm_sqrt_cosine\", \"cfm_refl_cosine\"]] = np.nan\n",
 "for i, cas in df_cas.iterrows():\n",
@@ -15766,16 +15938,26 @@
 " print(f\"{challenge} not found in CFM-ID results. Skipping.\")\n",
 " continue\n",
 " cfm_data = cf[cf[\"#ID\"] == challenge].iloc[0]\n",
- " \n",
- " \n",
+ "\n",
 " if cas[\"ChallengeName\"] != cfm_data[\"#ID\"]:\n",
 " raise ValueError(\"Wrong challenge matched\")\n",
- " cfm_peaks = cfm_data[\"peaks\" + cas[\"cfm_CE\"]] # find best reference CE\n",
+ " cfm_peaks = cfm_data[\"peaks\" + cas[\"cfm_CE\"]] # find best reference CE\n",
 " df_cas.at[i, \"cfm_peaks\"] = cfm_peaks\n",
 " df_cas.at[i, \"cfm_cosine\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks)\n",
- " df_cas.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks, transform=np.sqrt)\n",
- " df_cas.at[i, \"cfm_sqrt_cosine_wo_prec\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks, transform=np.sqrt, remove_mz=cas[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=cas[\"Metabolite\"].metadata[\"precursor_mode\"]))\n",
- " df_cas.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(cas[\"peaks\"], cfm_peaks, transform=np.sqrt)\n",
+ " df_cas.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(\n",
+ " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n",
+ " )\n",
+ " df_cas.at[i, \"cfm_sqrt_cosine_wo_prec\"] = spectral_cosine(\n",
+ " cas[\"peaks\"],\n",
+ " cfm_peaks,\n",
+ " transform=np.sqrt,\n",
+ " remove_mz=cas[\"Metabolite\"].get_theoretical_precursor_mz(\n",
+ " ion_type=cas[\"Metabolite\"].metadata[\"precursor_mode\"]\n",
+ " ),\n",
+ " )\n",
+ " df_cas.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(\n",
+ " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n",
+ " )\n",
 " df_cas.at[i, \"cfm_steins\"] = reweighted_dot(cas[\"peaks\"], cfm_peaks)"
 ]
 },
@@ -15799,12 +15981,18 @@
 "import fiora.IO.cfmReader as cfmReader\n",
 "# time CFM-ID 4: -> 12m16,571s\n",
 "\n",
- "cf = cfmReader.read(f\"{home}/data/metabolites/cfm-id/casmi16t_negative_predictions.txt\", as_df=True)\n",
- "cf_p = cfmReader.read(f\"{home}/data/metabolites/cfm-id/casmi16t_positive_predictions.txt\", as_df=True)\n",
+ "cf = cfmReader.read(\n",
+ " f\"{home}/data/metabolites/cfm-id/casmi16t_negative_predictions.txt\", as_df=True\n",
+ ")\n",
+ "cf_p = cfmReader.read(\n",
+ " f\"{home}/data/metabolites/cfm-id/casmi16t_positive_predictions.txt\", as_df=True\n",
+ ")\n",
 "cf = pd.concat([cf, cf_p])\n",
- "len(cf[cf[\"#ID\"] == \"Challenge-009\"]) ## missing chalenges\n",
+ "len(cf[cf[\"#ID\"] == \"Challenge-009\"]) ## missing challenges\n",
 "df_cast[\"cfm_peaks\"] = None\n",
- "df_cast[[\"cfm_cosine\", \"cfm_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\", \"cfm_refl_cosine\"]] = np.nan\n",
+ "df_cast[\n",
+ " [\"cfm_cosine\", \"cfm_sqrt_cosine\", \"cfm_sqrt_cosine_wo_prec\", \"cfm_refl_cosine\"]\n",
+ "] = np.nan\n",
 "for i, cas in df_cast.iterrows():\n",
 " challenge = cas[\"ChallengeName\"]\n",
 "\n",
 " if len(cf[cf[\"#ID\"] == challenge]) < 1:\n",
 " print(f\"{challenge} not found in CFM-ID results. 
Skipping.\")\n", " continue\n", " cfm_data = cf[cf[\"#ID\"] == challenge].iloc[0]\n", - " \n", - " \n", + "\n", " if cas[\"ChallengeName\"] != cfm_data[\"#ID\"]:\n", " raise ValueError(\"Wrong challenge matched\")\n", - " cfm_peaks = cfm_data[\"peaks\" + cas[\"cfm_CE\"]] # find best reference CE\n", + " cfm_peaks = cfm_data[\"peaks\" + cas[\"cfm_CE\"]] # find best reference CE\n", " df_cast.at[i, \"cfm_peaks\"] = cfm_peaks\n", " df_cast.at[i, \"cfm_cosine\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks)\n", - " df_cast.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks, transform=np.sqrt)\n", - " df_cast.at[i, \"cfm_sqrt_cosine_wo_prec\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks, transform=np.sqrt, remove_mz=cas[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=cas[\"Metabolite\"].metadata[\"precursor_mode\"]))\n", - " df_cast.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(cas[\"peaks\"], cfm_peaks, transform=np.sqrt)\n", + " df_cast.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(\n", + " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " )\n", + " df_cast.at[i, \"cfm_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", + " cas[\"peaks\"],\n", + " cfm_peaks,\n", + " transform=np.sqrt,\n", + " remove_mz=cas[\"Metabolite\"].get_theoretical_precursor_mz(\n", + " ion_type=cas[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " ),\n", + " )\n", + " df_cast.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(\n", + " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " )\n", " df_cast.at[i, \"cfm_steins\"] = reweighted_dot(cas[\"peaks\"], cfm_peaks)" ] }, @@ -15835,9 +16033,12 @@ "# time CFM-ID 4: -> 12m16,571s\n", "\n", "\n", - "\n", - "cf22 = cfmReader.read(f\"{home}/data/metabolites/cfm-id/casmi22_negative_predictions.txt\", as_df=True)\n", - "cf22_p = cfmReader.read(f\"{home}/data/metabolites/cfm-id/casmi22_positive_predictions.txt\", as_df=True)\n", + "cf22 = cfmReader.read(\n", + " f\"{home}/data/metabolites/cfm-id/casmi22_negative_predictions.txt\", as_df=True\n", + ")\n", + "cf22_p = cfmReader.read(\n", + " f\"{home}/data/metabolites/cfm-id/casmi22_positive_predictions.txt\", as_df=True\n", + ")\n", "cf22 = pd.concat([cf22, cf22_p])\n", "df_cas22[\"cfm_peaks\"] = None\n", "df_cas22[[\"cfm_cosine\", \"cfm_sqrt_cosine\", \"cfm_refl_cosine\"]] = np.nan\n", @@ -15848,18 +16049,28 @@ " print(f\"{challenge} not found in CFM-ID results. 
Skipping.\")\n", " continue\n", " cfm_data = cf22[cf22[\"#ID\"] == challenge].iloc[0]\n", - " \n", - " \n", + "\n", " if cas[\"ChallengeName\"] != cfm_data[\"#ID\"]:\n", " raise ValueError(\"Wrong challenge matched\")\n", - " cfm_peaks = cfm_data[\"peaks\" + cas[\"cfm_CE\"]] # find best reference CE\n", + " cfm_peaks = cfm_data[\"peaks\" + cas[\"cfm_CE\"]] # find best reference CE\n", " df_cas22.at[i, \"cfm_peaks\"] = cfm_peaks\n", " df_cas22.at[i, \"cfm_cosine\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks)\n", - " df_cas22.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks, transform=np.sqrt)\n", - " df_cas22.at[i, \"cfm_sqrt_cosine_wo_prec\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks, transform=np.sqrt, remove_mz=cas[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=cas[\"Metabolite\"].metadata[\"precursor_mode\"]))\n", - " df_cas22.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(cas[\"peaks\"], cfm_peaks, transform=np.sqrt)\n", + " df_cas22.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(\n", + " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " )\n", + " df_cas22.at[i, \"cfm_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", + " cas[\"peaks\"],\n", + " cfm_peaks,\n", + " transform=np.sqrt,\n", + " remove_mz=cas[\"Metabolite\"].get_theoretical_precursor_mz(\n", + " ion_type=cas[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " ),\n", + " )\n", + " df_cas22.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(\n", + " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " )\n", " df_cas22.at[i, \"cfm_steins\"] = reweighted_dot(cas[\"peaks\"], cfm_peaks)\n", - " \n", + "\n", "df_cas22[\"is_priority\"] = df_cas22[\"is_priority\"].astype(bool)" ] }, @@ -15869,11 +16080,9 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "ex_smiles = \"CC(NC(=O)CC1=CNC2=C1C=CC=C2)C(O)=O\"\n", "ex_metabolite = Metabolite(ex_smiles)\n", - "ex_compound_id = df_test[df_test[\"Metabolite\"] == ex_metabolite][\"group_id\"].iloc[0]\n", - "\n" + "ex_compound_id = df_test[df_test[\"Metabolite\"] == ex_metabolite][\"group_id\"].iloc[0]" ] }, { @@ -15918,37 +16127,46 @@ ], "source": [ "from fiora.visualization.define_colors import reset_matplotlib\n", + "\n", "reset_matplotlib()\n", "spec_df = {}\n", "for i, data in df_test[df_test[\"group_id\"] == ex_compound_id].iterrows():\n", " cosine = data[\"spectral_sqrt_cosine\"]\n", " name = data[\"Name\"]\n", - " #t3 = data[\"tanimoto3\"]\n", + " # t3 = data[\"tanimoto3\"]\n", " print(f\"{name} ({i}): cosine {cosine:0.2}\")\n", " # print(f\"{name} ({i}): cosine {t3}\") only possible after tanimoto calculation below\n", "\n", - " fig, axs = plt.subplots(1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", - " img = data[\"Metabolite\"].draw(ax= axs[0])\n", + " fig, axs = plt.subplots(\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " )\n", + " img = data[\"Metabolite\"].draw(ax=axs[0])\n", "\n", - " #axs[0].grid(False)\n", - " axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False) \n", + " # axs[0].grid(False)\n", + " axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " )\n", " axs[0].set_title(data[\"Name\"])\n", - " #axs[0].imshow(img)\n", - " #axs[0].axis(\"off\")\n", - " #sv.plot_spectrum(example, ax=axs[1])\n", + " # axs[0].imshow(img)\n", + " # axs[0].axis(\"off\")\n", + " # sv.plot_spectrum(example, ax=axs[1])\n", " prec = data[\"Precursor_type\"]\n", - " 
spec_df.update({\n", - " f\"Experimental m/z {prec}\": data[\"peaks\"][\"mz\"],\n", - " f\"Experimental intensity {prec}\": data[\"peaks\"][\"intensity\"],\n", - " f\"Fiora m/z {prec}\": data[\"sim_peaks\"][\"mz\"],\n", - " f\"Fiora intensity {prec}\": data[\"sim_peaks\"][\"intensity\"],\n", - " })\n", - " \n", - " ax = sv.plot_spectrum(data, {\"peaks\": data[\"sim_peaks\"]}, ax=axs[1], highlight_matches=False)\n", + " spec_df.update(\n", + " {\n", + " f\"Experimental m/z {prec}\": data[\"peaks\"][\"mz\"],\n", + " f\"Experimental intensity {prec}\": data[\"peaks\"][\"intensity\"],\n", + " f\"Fiora m/z {prec}\": data[\"sim_peaks\"][\"mz\"],\n", + " f\"Fiora intensity {prec}\": data[\"sim_peaks\"][\"intensity\"],\n", + " }\n", + " )\n", + "\n", + " ax = sv.plot_spectrum(\n", + " data, {\"peaks\": data[\"sim_peaks\"]}, ax=axs[1], highlight_matches=False\n", + " )\n", " # fig.savefig(f\"{home}/images/paper/example_mirror_id{i}.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # # fig.savefig(f\"{home}/images/paper/example_mirror_id{i}.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/example_mirror_id{i}.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", - " plt.show()\n" + " plt.show()" ] }, { @@ -15957,7 +16175,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "spec_df = pd.DataFrame.from_dict(spec_df, orient=\"index\").transpose()\n", "spec_df = spec_df.fillna(\"-\")" ] @@ -15968,7 +16185,7 @@ "metadata": {}, "outputs": [], "source": [ - "#spec_df.to_excel(f\"{home}/images/paper/SF2.xlsx\")" + "# spec_df.to_excel(f\"{home}/images/paper/SF2.xlsx\")" ] }, { @@ -15987,8 +16204,10 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "paths = {\"[M-H]-\": f\"{home}/data/metabolites/cfm-id/test_pred_neg/\", \"[M+H]+\": f\"{home}/data/metabolites/cfm-id/test_pred_pos/\"}\n", + "paths = {\n", + " \"[M-H]-\": f\"{home}/data/metabolites/cfm-id/test_pred_neg/\",\n", + " \"[M+H]+\": f\"{home}/data/metabolites/cfm-id/test_pred_pos/\",\n", + "}\n", "df_test[\"cfm_CE\"] = df_test[\"CE\"].apply(closest_cfm_ce)\n", "df_test[\"cfm_peaks\"] = None\n", "df_test[[\"cfm_cosine\", \"cfm_sqrt_cosine\", \"cfm_refl_cosine\"]] = np.nan\n", @@ -15996,19 +16215,30 @@ " group_id = data[\"group_id\"]\n", " p = paths[data[\"Precursor_type\"]] + str(int(group_id)) + \".txt\"\n", " cfm_data = cfmReader.read(p, as_df=True)\n", - " \n", + "\n", " # TODO Check smiles / MOL\n", - " if cfm_data.shape == (0, 0): # Not predicted by CFM-ID\n", + " if cfm_data.shape == (0, 0): # Not predicted by CFM-ID\n", " continue\n", - " \n", + "\n", " cfm_peaks = cfm_data[\"peaks\" + data[\"cfm_CE\"]].iloc[0]\n", "\n", " df_test.at[i, \"cfm_peaks\"] = cfm_peaks\n", " df_test.at[i, \"cfm_cosine\"] = spectral_cosine(data[\"peaks\"], cfm_peaks)\n", - " df_test.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(data[\"peaks\"], cfm_peaks, transform=np.sqrt)\n", - " df_test.at[i, \"cfm_sqrt_cosine_wo_prec\"] = spectral_cosine(data[\"peaks\"], cfm_peaks, transform=np.sqrt, remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]))\n", - " df_test.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(data[\"peaks\"], cfm_peaks, transform=np.sqrt)\n", - " df_test.at[i, \"cfm_steins\"] = reweighted_dot(data[\"peaks\"], cfm_peaks)\n" + " df_test.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(\n", + " data[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " )\n", + " df_test.at[i, 
\"cfm_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", + " data[\"peaks\"],\n", + " cfm_peaks,\n", + " transform=np.sqrt,\n", + " remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(\n", + " ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " ),\n", + " )\n", + " df_test.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(\n", + " data[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " )\n", + " df_test.at[i, \"cfm_steins\"] = reweighted_dot(data[\"peaks\"], cfm_peaks)" ] }, { @@ -16025,7 +16255,7 @@ "metadata": {}, "outputs": [], "source": [ - "# This section is dedicated to generate the reference libraries and inputs for CFM-ID and ICEBERG. \n", + "# This section is dedicated to generate the reference libraries and inputs for CFM-ID and ICEBERG.\n", "# Turn on only if training or test datasets change.\n", "\n", "\n", @@ -16033,73 +16263,117 @@ "if False:\n", " file = f\"{home}/data/metabolites/cfm-id/test_split_negative_solutions_cfm.txt\"\n", " df_test[\"group_id\"] = df_test[\"group_id\"].astype(int)\n", - " df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"][[\"group_id\", \"SMILES\"]].drop_duplicates(\"group_id\").to_csv(file, index=False, header=False, sep=\" \")\n", + " df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"][\n", + " [\"group_id\", \"SMILES\"]\n", + " ].drop_duplicates(\"group_id\").to_csv(file, index=False, header=False, sep=\" \")\n", " file = file.replace(\"negative\", \"positive\")\n", - " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][[\"group_id\", \"SMILES\"]].drop_duplicates(\"group_id\").to_csv(file, index=False, header=False, sep=\" \")\n", + " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][\n", + " [\"group_id\", \"SMILES\"]\n", + " ].drop_duplicates(\"group_id\").to_csv(file, index=False, header=False, sep=\" \")\n", "\n", - "if False: \n", - " file = f\"{home}/data/metabolites/cfm-id/msnlib_test_split_negative_solutions_cfm.txt\"\n", + "if False:\n", + " file = (\n", + " f\"{home}/data/metabolites/cfm-id/msnlib_test_split_negative_solutions_cfm.txt\"\n", + " )\n", " df_msnlib_test[\"group_id\"] = df_msnlib_test[\"group_id\"].astype(int)\n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M-H]-\"][[\"group_id\", \"SMILES\"]].drop_duplicates(\"group_id\").to_csv(file, index=False, header=False, sep=\" \")\n", + " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M-H]-\"][\n", + " [\"group_id\", \"SMILES\"]\n", + " ].drop_duplicates(\"group_id\").to_csv(file, index=False, header=False, sep=\" \")\n", " file = file.replace(\"negative\", \"positive\")\n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][[\"group_id\", \"SMILES\"]].drop_duplicates(\"group_id\").to_csv(file, index=False, header=False, sep=\" \")\n", + " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][\n", + " [\"group_id\", \"SMILES\"]\n", + " ].drop_duplicates(\"group_id\").to_csv(file, index=False, header=False, sep=\" \")\n", "\n", "# ICEBERG/SCARF training/testing input\n", "if False:\n", " # OLD # df_test[\"idx\"] = [f\"spec{i}\" for i,_ in df_test.iterrows()]\n", " df_test[\"num\"] = df_test.groupby(\"group_id\").cumcount() + 1\n", - " df_test[\"idx\"] = \"spec\" + df_test[\"group_id\"].astype(int).astype(str) + \"_\" + df_test[\"num\"].astype(str)\n", + " df_test[\"idx\"] = (\n", + " \"spec\"\n", + " + df_test[\"group_id\"].astype(int).astype(str)\n", + " + \"_\"\n", + " + df_test[\"num\"].astype(str)\n", + " )\n", "\n", - " \n", - " #df_train[\"dataset_label\"] = \"df_test\"\n", - " label_map = {\"idx\": 
\"spec\", \"Name\": \"name\", \"Precursor_type\": \"ionization\", \"SMILES\": \"smiles\", \"InChIKey\": \"inchikey\"}\n", + " # df_train[\"dataset_label\"] = \"df_test\"\n", + " label_map = {\n", + " \"idx\": \"spec\",\n", + " \"Name\": \"name\",\n", + " \"Precursor_type\": \"ionization\",\n", + " \"SMILES\": \"smiles\",\n", + " \"InChIKey\": \"inchikey\",\n", + " }\n", " df_test[\"formula\"] = df_test[\"Metabolite\"].apply(lambda x: x.Formula)\n", " df_test[\"InChIKey\"] = df_test[\"Metabolite\"].apply(lambda x: x.InChIKey)\n", - " #import fiora.IO.mspredWriter as mspredWriter WRITER bugged?\n", - " #mspredWriter.write_labels(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"], f\"{home}/data/metabolites/ms-pred/df_test.tsv\", label_map=label_map)\n", - " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"].rename(columns=label_map)[[\"dataset\", \"spec\", \"name\", \"ionization\", \"formula\", \"smiles\", \"inchikey\"]].to_csv(f\"{home}/data/metabolites/ms-pred/df_test.tsv\", index=False, sep=\"\\t\")\n", - "\n", - "if False: # MSnLib\n", + " # import fiora.IO.mspredWriter as mspredWriter WRITER bugged?\n", + " # mspredWriter.write_labels(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"], f\"{home}/data/metabolites/ms-pred/df_test.tsv\", label_map=label_map)\n", + " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"].rename(columns=label_map)[\n", + " [\"dataset\", \"spec\", \"name\", \"ionization\", \"formula\", \"smiles\", \"inchikey\"]\n", + " ].to_csv(f\"{home}/data/metabolites/ms-pred/df_test.tsv\", index=False, sep=\"\\t\")\n", "\n", + "if False: # MSnLib\n", " df_msnlib_test[\"num\"] = df_msnlib_test.groupby(\"group_id\").cumcount() + 1\n", - " df_msnlib_test[\"idx\"] = \"spec\" + df_msnlib_test[\"group_id\"].astype(int).astype(str) + \"_\" + df_msnlib_test[\"num\"].astype(str)\n", + " df_msnlib_test[\"idx\"] = (\n", + " \"spec\"\n", + " + df_msnlib_test[\"group_id\"].astype(int).astype(str)\n", + " + \"_\"\n", + " + df_msnlib_test[\"num\"].astype(str)\n", + " )\n", + "\n", + " # df_train[\"dataset_label\"] = \"df_msnlib_test\"\n", + " label_map = {\n", + " \"idx\": \"spec\",\n", + " \"NAME\": \"name\",\n", + " \"Precursor_type\": \"ionization\",\n", + " \"SMILES\": \"smiles\",\n", + " \"INCHIAUX\": \"inchikey\",\n", + " }\n", "\n", - " \n", - " #df_train[\"dataset_label\"] = \"df_msnlib_test\"\n", - " label_map = {\"idx\": \"spec\", \"NAME\": \"name\", \"Precursor_type\": \"ionization\", \"SMILES\": \"smiles\", \"INCHIAUX\": \"inchikey\"}\n", - " \n", " df_msnlib_test[\"formula\"] = df_msnlib_test[\"Metabolite\"].apply(lambda x: x.Formula)\n", - " df_msnlib_test[\"InChIKey\"] = df_msnlib_test[\"Metabolite\"].apply(lambda x: x.InChIKey)\n", - " #import fiora.IO.mspredWriter as mspredWriter WRITER bugged?\n", - " #mspredWriter.write_labels(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"], f\"{home}/data/metabolites/ms-pred/df_test.tsv\", label_map=label_map)\n", - " \n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"].rename(columns=label_map)[[\"dataset\", \"spec\", \"name\", \"ionization\", \"formula\", \"smiles\", \"inchikey\"]].to_csv(f\"{home}/data/metabolites/ms-pred/df_msnlib_test.tsv\", index=False, sep=\"\\t\")\n", + " df_msnlib_test[\"InChIKey\"] = df_msnlib_test[\"Metabolite\"].apply(\n", + " lambda x: x.InChIKey\n", + " )\n", + " # import fiora.IO.mspredWriter as mspredWriter WRITER bugged?\n", + " # mspredWriter.write_labels(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"], f\"{home}/data/metabolites/ms-pred/df_test.tsv\", label_map=label_map)\n", 
+ "\n", + " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"].rename(\n", + " columns=label_map\n", + " )[\n", + " [\"dataset\", \"spec\", \"name\", \"ionization\", \"formula\", \"smiles\", \"inchikey\"]\n", + " ].to_csv(\n", + " f\"{home}/data/metabolites/ms-pred/df_msnlib_test.tsv\", index=False, sep=\"\\t\"\n", + " )\n", "\n", "if False:\n", " # ## CASMI\n", " # # prepare output for ICEBERG and SCARF\n", " # # from rdkit import Chem\n", " # # from rdkit.Chem import rdMolDescriptors\n", - " label_map = {\"idx\": \"spec\", \"Precursor_type\": \"ionization\", \"SMILES\": \"smiles\", \"InChIKey\": \"inchikey\"}\n", - " df_cas22[\"idx\"] = [f\"spec{i}\" for i,_ in df_cas22.iterrows()]\n", + " label_map = {\n", + " \"idx\": \"spec\",\n", + " \"Precursor_type\": \"ionization\",\n", + " \"SMILES\": \"smiles\",\n", + " \"InChIKey\": \"inchikey\",\n", + " }\n", + " df_cas22[\"idx\"] = [f\"spec{i}\" for i, _ in df_cas22.iterrows()]\n", " df_cas22[\"name\"] = \"Unknown\"\n", " df_cas22[\"InChIKey\"] = df_cas22[\"Metabolite\"].apply(lambda x: x.InChIKey)\n", " df_cas22[\"formula\"] = df_cas22[\"Metabolite\"].apply(lambda x: x.Formula)\n", - " \n", + "\n", " output_file = f\"{home}/data/metabolites/ms-pred/casmi22_positive_labels.tsv\"\n", - " df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"].rename(columns=label_map)[[\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\",\t\"smiles\", \"inchikey\"]].to_csv(output_file, index=False, sep=\"\\t\")\n", - " \n", + " df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"].rename(columns=label_map)[\n", + " [\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\", \"smiles\", \"inchikey\"]\n", + " ].to_csv(output_file, index=False, sep=\"\\t\")\n", + "\n", " # # ### CASMI-16 labels were generated in Casmi16 loader. Avoiding repeat here.\n", "\n", - " \n", - " \n", " # df_test[\"MOL\"] = df_test[\"SMILES\"].apply(Chem.MolFromSmiles)\n", " # df_test[\"formula\"] = df_test[\"MOL\"].apply(rdMolDescriptors.CalcMolFormula)\n", " # df_test[\"dataset\"] = \"df_test\"\n", " # df_test = df_test.rename(columns={\"FILE\": \"spec\", \"ChallengeName\": \"name\", \"Precursor_type\": \"ionization\", \"SMILES\": \"smiles\", \"INCHIKEY\": \"inchikey\"})\n", - " \n", + "\n", " # output_file = f\"{home}/data/metabolites/ms-pred/casmi16_positive_labels.tsv\"\n", - " #df[df[\"ionization\"] == \"[M+H]+\"][[\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\",\t\"smiles\", \"inchikey\"]].to_csv(output_file, index=False, sep=\"\\t\")\n", + " # df[df[\"ionization\"] == \"[M+H]+\"][[\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\",\t\"smiles\", \"inchikey\"]].to_csv(output_file, index=False, sep=\"\\t\")\n", "\n", "\n", "##### SAVE TRAINING SET IN MS_PRED FORMAT\n", @@ -16123,7 +16397,7 @@ "\n", "# # In order to rerun without errors, select only spec that doesn't cause an error. 
(requires running magma with the whole dataset first)\n", "# working = set([s.split(\".\")[0] for s in os.listdir(d)])\n", - "#mspredWriter.write_dataset(df_train[df_train[\"idx\"].apply(lambda x: x in working)], f\"{home}/repos/ms-pred/data/spec_datasets/df_train/\", label_map=header_map)\n", + "# mspredWriter.write_dataset(df_train[df_train[\"idx\"].apply(lambda x: x in working)], f\"{home}/repos/ms-pred/data/spec_datasets/df_train/\", label_map=header_map)\n", "\n", "# # Create split.tsv\n", "# df_split = df_train[[\"idx\", \"dataset\"]].copy() #df_train[df_train[\"idx\"].apply(lambda x: x in working)][[\"idx\", \"dataset\"]].copy()\n", @@ -16132,8 +16406,7 @@ "# df_split[\"Fold_0\"] = df_split[\"Fold_0\"].str.replace(\"training\", \"train\")\n", "# df_split.to_csv(f\"{home}/repos/ms-pred/data/spec_datasets/df_train/splits/split1.tsv\", index=False, sep=\"\\t\")\n", "\n", - "# Dataframes to ms-pred input format\n", - "\n" + "# Dataframes to ms-pred input format" ] }, { @@ -16157,13 +16430,10 @@ "# df_mnlib_test = 238m (with demo notebook)\n", "\n", "\n", - "\n", - "\n", - "\n", "# Rerunning EVERYTHING AGAIN (avoiding crashes this time, fingers crossed)\n", - "# \n", + "#\n", "# time . data_scripts/y_all_assign_subform.sh: 2471m28,695s\n", - "# time . data_scripts/dag/y_run_magma.sh: 1810m9,266s \n", + "# time . data_scripts/dag/y_run_magma.sh: 1810m9,266s\n", "# 64222 spec\n", "#\n", "# time bash run_scripts/dag_model/01_run_dag_gen_train.sh: 1151m8,021s\n", @@ -16176,8 +16446,7 @@ "# Casmi 16: 1m40,373s\n", "# All: 80m48,594s (crashed no attribute max fragment depth)\n", "# df_test extra: 459.5min\n", - "# df_msnlib_test: \n", - "\n", + "# df_msnlib_test:\n", "\n", "\n", "# raise KeyboardInterrupt(\"HALT\")\n", @@ -16188,10 +16457,10 @@ "# 01 dag train: 66117.73s\n", "# 02 sweep: 0\n", "# 03 dag_gen_predict: 405m21,170s (with crash)\n", - "# 04 \n", + "# 04\n", "\n", "\n", - "# v1.0.0 doesn't allow proper retraining, many scripts missing. \n", + "# v1.0.0 doesn't allow proper retraining, many scripts missing.\n", "# -> Running on commit: update scarf figs to \\pm SEM, instead of confidence -> f56d601e843351e4bdfe0e173d65e8105a1\n", "# run magma took: 347m20,902s\n", "# run subformula: 620m54,021s\n", @@ -16201,7 +16470,7 @@ "\n", "\n", "# Just positives CASMI16. ICEBERG took 98.5s (--gpu threw error though)\n", - "#raise KeyboardInterrupt()\n" + "# raise KeyboardInterrupt()\n" ] }, { @@ -16220,13 +16489,13 @@ "# xxx = []\n", "# for i,d in df_cas.iterrows():\n", "# m = d[\"Metabolite\"]\n", - " \n", + "\n", "# for x,D in icy.iterrows():\n", "# M = D[\"Metabolite\"]\n", "# if (m == M):\n", "# iii += [i]\n", - "# xxx += [x] \n", - "# iii = np.unique(iii) \n", + "# xxx += [x]\n", + "# iii = np.unique(iii)\n", "# print(f\"Found {len(iii)} instances violating test/train split. 
Metabolite found in train/val set.\") # out of 123\n", "# print(f\"Dropping {len(xxx)} spectra from training DataFrame.\")" ] @@ -16237,22 +16506,41 @@ "metadata": {}, "outputs": [], "source": [ - "import fiora.IO.mspredReader as mspredReader \n", + "import fiora.IO.mspredReader as mspredReader\n", + "\n", "iceberg_dir = f\"{home}/repos/ms-pred/results/test_out_recovery/casmi16/tree_preds_inten\"\n", "df_ice = mspredReader.read(iceberg_dir)\n", "df_ice = df_ice.rename(columns={\"peaks\": \"ice_peaks\", \"name\": \"ice_name\"})\n", "\n", "\n", - "df_cas = pd.merge(df_cas, df_ice[[\"ice_name\", \"ice_peaks\"]], left_on='ChallengeName', right_on='ice_name', how='left')\n", + "df_cas = pd.merge(\n", + " df_cas,\n", + " df_ice[[\"ice_name\", \"ice_peaks\"]],\n", + " left_on=\"ChallengeName\",\n", + " right_on=\"ice_name\",\n", + " how=\"left\",\n", + ")\n", "\n", - "df_cas[[\"ice_cosine\", \"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\", \"ice_refl_cosine\"]] = np.nan\n", + "df_cas[\n", + " [\"ice_cosine\", \"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\", \"ice_refl_cosine\"]\n", + "] = np.nan\n", "for i, data in df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"].iterrows():\n", - "\n", " df_cas.at[i, \"ice_cosine\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"])\n", - " df_cas.at[i, \"ice_sqrt_cosine\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt)\n", - " df_cas.at[i, \"ice_sqrt_cosine_wo_prec\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt, remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]))\n", - " df_cas.at[i, \"ice_refl_cosine\"] = spectral_reflection_cosine(data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt)\n", - " df_cas.at[i, \"ice_steins\"] = reweighted_dot(data[\"peaks\"], data[\"ice_peaks\"])\n" + " df_cas.at[i, \"ice_sqrt_cosine\"] = spectral_cosine(\n", + " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " )\n", + " df_cas.at[i, \"ice_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", + " data[\"peaks\"],\n", + " data[\"ice_peaks\"],\n", + " transform=np.sqrt,\n", + " remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(\n", + " ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " ),\n", + " )\n", + " df_cas.at[i, \"ice_refl_cosine\"] = spectral_reflection_cosine(\n", + " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " )\n", + " df_cas.at[i, \"ice_steins\"] = reweighted_dot(data[\"peaks\"], data[\"ice_peaks\"])" ] }, { @@ -16261,13 +16549,22 @@ "metadata": {}, "outputs": [], "source": [ - "import fiora.IO.mspredReader as mspredReader \n", - "iceberg_dir = f\"{home}/repos/ms-pred/results/test_out_recovery/casmi16t/tree_preds_inten\"\n", + "import fiora.IO.mspredReader as mspredReader\n", + "\n", + "iceberg_dir = (\n", + " f\"{home}/repos/ms-pred/results/test_out_recovery/casmi16t/tree_preds_inten\"\n", + ")\n", "df_ice = mspredReader.read(iceberg_dir)\n", "df_ice = df_ice.rename(columns={\"peaks\": \"ice_peaks\", \"name\": \"ice_name\"})\n", "\n", "\n", - "df_cast = pd.merge(df_cast, df_ice[[\"ice_name\", \"ice_peaks\"]], left_on='ChallengeName', right_on='ice_name', how='left')" + "df_cast = pd.merge(\n", + " df_cast,\n", + " df_ice[[\"ice_name\", \"ice_peaks\"]],\n", + " left_on=\"ChallengeName\",\n", + " right_on=\"ice_name\",\n", + " how=\"left\",\n", + ")" ] }, { @@ -16276,13 +16573,25 @@ "metadata": {}, "outputs": [], "source": [ - "df_cast[[\"ice_cosine\", \"ice_sqrt_cosine\", 
\"ice_sqrt_cosine_wo_prec\" \"ice_refl_cosine\"]] = np.nan\n", + "df_cast[[\"ice_cosine\", \"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_precice_refl_cosine\"]] = (\n", + " np.nan\n", + ")\n", "for i, data in df_cast[df_cast[\"Precursor_type\"] == \"[M+H]+\"].iterrows():\n", - "\n", " df_cast.at[i, \"ice_cosine\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"])\n", - " df_cast.at[i, \"ice_sqrt_cosine\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt)\n", - " df_cast.at[i, \"ice_sqrt_cosine_wo_prec\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt, remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]))\n", - " df_cast.at[i, \"ice_refl_cosine\"] = spectral_reflection_cosine(data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt)\n", + " df_cast.at[i, \"ice_sqrt_cosine\"] = spectral_cosine(\n", + " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " )\n", + " df_cast.at[i, \"ice_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", + " data[\"peaks\"],\n", + " data[\"ice_peaks\"],\n", + " transform=np.sqrt,\n", + " remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(\n", + " ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " ),\n", + " )\n", + " df_cast.at[i, \"ice_refl_cosine\"] = spectral_reflection_cosine(\n", + " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " )\n", " df_cast.at[i, \"ice_steins\"] = reweighted_dot(data[\"peaks\"], data[\"ice_peaks\"])" ] }, @@ -16296,8 +16605,14 @@ "df_ice = mspredReader.read(iceberg_dir)\n", "df_ice = df_ice.rename(columns={\"peaks\": \"ice_peaks\", \"name\": \"ice_name\"})\n", "\n", - "df_cas22[\"idx\"] = [f\"spec{i}\" for i,_ in df_cas22.iterrows()]\n", - "df_cas22 = pd.merge(df_cas22, df_ice[[\"ice_name\", \"ice_peaks\"]], left_on='idx', right_on='ice_name', how='left')" + "df_cas22[\"idx\"] = [f\"spec{i}\" for i, _ in df_cas22.iterrows()]\n", + "df_cas22 = pd.merge(\n", + " df_cas22,\n", + " df_ice[[\"ice_name\", \"ice_peaks\"]],\n", + " left_on=\"idx\",\n", + " right_on=\"ice_name\",\n", + " how=\"left\",\n", + ")" ] }, { @@ -16306,16 +16621,27 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "\n", - "df_cas22[[\"ice_cosine\", \"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\", \"ice_refl_cosine\"]] = np.nan\n", + "df_cas22[\n", + " [\"ice_cosine\", \"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\", \"ice_refl_cosine\"]\n", + "] = np.nan\n", "for i, data in df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"].iterrows():\n", - " #print(i, data[\"ice_peaks\"], data[\"ice_peaks\"] is not np.nan)\n", + " # print(i, data[\"ice_peaks\"], data[\"ice_peaks\"] is not np.nan)\n", " if data[\"ice_peaks\"] is not np.nan:\n", " df_cas22.at[i, \"ice_cosine\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"])\n", - " df_cas22.at[i, \"ice_sqrt_cosine\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt)\n", - " df_cas22.at[i, \"ice_sqrt_cosine_wo_prec\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt, remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]))\n", - " df_cas22.at[i, \"ice_refl_cosine\"] = spectral_reflection_cosine(data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt)\n", + " df_cas22.at[i, \"ice_sqrt_cosine\"] = spectral_cosine(\n", + " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " )\n", + " df_cas22.at[i, \"ice_sqrt_cosine_wo_prec\"] = 
spectral_cosine(\n", + " data[\"peaks\"],\n", + " data[\"ice_peaks\"],\n", + " transform=np.sqrt,\n", + " remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(\n", + " ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " ),\n", + " )\n", + " df_cas22.at[i, \"ice_refl_cosine\"] = spectral_reflection_cosine(\n", + " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " )\n", " df_cas22.at[i, \"ice_steins\"] = reweighted_dot(data[\"peaks\"], data[\"ice_peaks\"])" ] }, @@ -16417,21 +16743,46 @@ "df_ice = df_ice.rename(columns={\"peaks\": \"ice_peaks\", \"name\": \"ice_name\"})\n", "\n", "df_test[\"num\"] = df_test.groupby(\"group_id\").cumcount() + 1\n", - "df_test[\"idx\"] = \"spec\" + df_test[\"group_id\"].astype(int).astype(str) + \"_\" + df_test[\"num\"].astype(str) #df_test[\"idx\"] = [f\"spec{i}\" for i,_ in df_test.iterrows()]\n", + "df_test[\"idx\"] = (\n", + " \"spec\"\n", + " + df_test[\"group_id\"].astype(int).astype(str)\n", + " + \"_\"\n", + " + df_test[\"num\"].astype(str)\n", + ") # df_test[\"idx\"] = [f\"spec{i}\" for i,_ in df_test.iterrows()]\n", "ori_idx = df_test.index.copy()\n", - "df_test = pd.merge(df_test, df_ice[[\"ice_name\", \"ice_peaks\"]], left_on='idx', right_on='ice_name', how='left')\n", + "df_test = pd.merge(\n", + " df_test,\n", + " df_ice[[\"ice_name\", \"ice_peaks\"]],\n", + " left_on=\"idx\",\n", + " right_on=\"ice_name\",\n", + " how=\"left\",\n", + ")\n", "df_test.index = ori_idx\n", - "#df_test.index = df_test[\"idx\"].str.extract(r'spec(\\d+)', expand=False).astype(int) TODO CHECK what happens to the index\n", + "# df_test.index = df_test[\"idx\"].str.extract(r'spec(\\d+)', expand=False).astype(int) TODO CHECK what happens to the index\n", "\n", - "df_test[[\"ice_cosine\", \"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\", \"ice_refl_cosine\"]] = np.nan\n", + "df_test[\n", + " [\"ice_cosine\", \"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\", \"ice_refl_cosine\"]\n", + "] = np.nan\n", "for i, data in df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"].iterrows():\n", " try:\n", " df_test.at[i, \"ice_cosine\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"])\n", - " df_test.at[i, \"ice_sqrt_cosine\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt)\n", - " df_test.at[i, \"ice_sqrt_cosine_wo_prec\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt, remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]))\n", - " df_test.at[i, \"ice_refl_cosine\"] = spectral_reflection_cosine(data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt)\n", + " df_test.at[i, \"ice_sqrt_cosine\"] = spectral_cosine(\n", + " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " )\n", + " df_test.at[i, \"ice_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", + " data[\"peaks\"],\n", + " data[\"ice_peaks\"],\n", + " transform=np.sqrt,\n", + " remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(\n", + " ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " ),\n", + " )\n", + " df_test.at[i, \"ice_refl_cosine\"] = spectral_reflection_cosine(\n", + " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " )\n", " df_test.at[i, \"ice_steins\"] = reweighted_dot(data[\"peaks\"], data[\"ice_peaks\"])\n", - " except: pass" + " except:\n", + " pass" ] }, { @@ -16440,26 +16791,57 @@ "metadata": {}, "outputs": [], "source": [ - "iceberg_dir = 
f\"{home}/repos/ms-pred/results/test_out_recovery/df_msnlib_test/tree_preds_inten\"\n", + "iceberg_dir = (\n", + " f\"{home}/repos/ms-pred/results/test_out_recovery/df_msnlib_test/tree_preds_inten\"\n", + ")\n", "df_ice = mspredReader.read(iceberg_dir)\n", "df_ice = df_ice.rename(columns={\"peaks\": \"ice_peaks\", \"name\": \"ice_name\"})\n", "\n", "df_msnlib_test[\"num\"] = df_msnlib_test.groupby(\"group_id\").cumcount() + 1\n", - "df_msnlib_test[\"idx\"] = \"spec\" + df_msnlib_test[\"group_id\"].astype(int).astype(str) + \"_\" + df_msnlib_test[\"num\"].astype(str) #df_msnlib_test[\"idx\"] = [f\"spec{i}\" for i,_ in df_msnlib_test.iterrows()]\n", + "df_msnlib_test[\"idx\"] = (\n", + " \"spec\"\n", + " + df_msnlib_test[\"group_id\"].astype(int).astype(str)\n", + " + \"_\"\n", + " + df_msnlib_test[\"num\"].astype(str)\n", + ") # df_msnlib_test[\"idx\"] = [f\"spec{i}\" for i,_ in df_msnlib_test.iterrows()]\n", "ori_idx = df_msnlib_test.index.copy()\n", - "df_msnlib_test = pd.merge(df_msnlib_test, df_ice[[\"ice_name\", \"ice_peaks\"]], left_on='idx', right_on='ice_name', how='left')\n", + "df_msnlib_test = pd.merge(\n", + " df_msnlib_test,\n", + " df_ice[[\"ice_name\", \"ice_peaks\"]],\n", + " left_on=\"idx\",\n", + " right_on=\"ice_name\",\n", + " how=\"left\",\n", + ")\n", "df_msnlib_test.index = ori_idx\n", - "#df_msnlib_test.index = df_msnlib_test[\"idx\"].str.extract(r'spec(\\d+)', expand=False).astype(int) TODO CHECK what happens to the index\n", + "# df_msnlib_test.index = df_msnlib_test[\"idx\"].str.extract(r'spec(\\d+)', expand=False).astype(int) TODO CHECK what happens to the index\n", "\n", - "df_msnlib_test[[\"ice_cosine\", \"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\", \"ice_refl_cosine\"]] = np.nan\n", + "df_msnlib_test[\n", + " [\"ice_cosine\", \"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\", \"ice_refl_cosine\"]\n", + "] = np.nan\n", "for i, data in df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"].iterrows():\n", " try:\n", - " df_msnlib_test.at[i, \"ice_cosine\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"])\n", - " df_msnlib_test.at[i, \"ice_sqrt_cosine\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt)\n", - " df_msnlib_test.at[i, \"ice_sqrt_cosine_wo_prec\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt, remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]))\n", - " df_msnlib_test.at[i, \"ice_refl_cosine\"] = spectral_reflection_cosine(data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt)\n", - " df_msnlib_test.at[i, \"ice_steins\"] = reweighted_dot(data[\"peaks\"], data[\"ice_peaks\"])\n", - " except: pass" + " df_msnlib_test.at[i, \"ice_cosine\"] = spectral_cosine(\n", + " data[\"peaks\"], data[\"ice_peaks\"]\n", + " )\n", + " df_msnlib_test.at[i, \"ice_sqrt_cosine\"] = spectral_cosine(\n", + " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " )\n", + " df_msnlib_test.at[i, \"ice_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", + " data[\"peaks\"],\n", + " data[\"ice_peaks\"],\n", + " transform=np.sqrt,\n", + " remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(\n", + " ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " ),\n", + " )\n", + " df_msnlib_test.at[i, \"ice_refl_cosine\"] = spectral_reflection_cosine(\n", + " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " )\n", + " df_msnlib_test.at[i, \"ice_steins\"] = reweighted_dot(\n", + " data[\"peaks\"], 
data[\"ice_peaks\"]\n", + " )\n", + " except:\n", + " pass" ] }, { @@ -16493,10 +16875,19 @@ "source": [ "from matplotlib import pyplot as plt\n", "\n", - "fig, ax = plt.subplots(figsize=(6,3.5))\n", + "fig, ax = plt.subplots(figsize=(6, 3.5))\n", "i = 81\n", - "ax = sv.plot_spectrum(df_cas.iloc[i], {\"peaks\": df_cas.iloc[i][\"ice_peaks\"]}, title=df_cas.iloc[i][\"ChallengeName\"] + \" vs ICEBERG pred: \" + df_cas.iloc[i][\"ice_name\"], highlight_matches=False, with_grid=False, ax=ax)\n", - "#ax.spines['bottom'].set_position(('outward', 10))\n", + "ax = sv.plot_spectrum(\n", + " df_cas.iloc[i],\n", + " {\"peaks\": df_cas.iloc[i][\"ice_peaks\"]},\n", + " title=df_cas.iloc[i][\"ChallengeName\"]\n", + " + \" vs ICEBERG pred: \"\n", + " + df_cas.iloc[i][\"ice_name\"],\n", + " highlight_matches=False,\n", + " with_grid=False,\n", + " ax=ax,\n", + ")\n", + "# ax.spines['bottom'].set_position(('outward', 10))\n", "plt.show()" ] }, @@ -16673,11 +17064,94 @@ "score = \"spectral_sqrt_cosine\"\n", "avg_func = np.median\n", "\n", - "fiora_res = {\"model\": \"Fiora\", \"Test+\": avg_func(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][score]), \"Test-\": avg_func(df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"][score]), \"MSnLib+\": avg_func(df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)), \"MSnLib-\": avg_func(df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)), \"CASMI16+\": avg_func(df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][score]), \"CASMI16-\":avg_func(df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][score]), \"CASMI22+\": avg_func(df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][score]), \"CASMI22-\": avg_func(df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][score])} \n", - "cfm_id = {\"model\": \"CFM-ID 4.4.7\", \"Test+\": avg_func(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"cfm\")].fillna(0.0)), \"Test-\": avg_func(df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"][score.replace(\"spectral\", \"cfm\")].fillna(0.0)), \"MSnLib+\": avg_func(df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"cfm\")].fillna(0.0)), \"MSnLib-\": avg_func(df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M-H]-\"][score.replace(\"spectral\", \"cfm\")].fillna(0.0)), \"CASMI16+\": avg_func(df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"cfm\")].fillna(0.0)), \"CASMI16-\": avg_func(df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][score.replace(\"spectral\", \"cfm\")].fillna(0.0)), \"CASMI22+\": avg_func(df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"cfm\")]), \"CASMI22-\": avg_func(df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][score.replace(\"spectral\", \"cfm\")])} \n", - "ice_res = {\"model\": \"ICEBERG\", \"Test+\": avg_func(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"ice\")].fillna(0.0)), \"MSnLib+\": avg_func(df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"ice\")].fillna(0.0)), \"CASMI16+\": avg_func(df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"ice\")].fillna(0.0)), \"CASMI22+\": avg_func(df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"ice\")].fillna(0.0)), \"CASMI22-\": avg_func(df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][score.replace(\"spectral\", \"ice\")].fillna(0.0))} \n", - "\n", - "summaryPos = 
pd.DataFrame( [fiora_res, cfm_id, ice_res])\n", + "fiora_res = {\n", + " \"model\": \"Fiora\",\n", + " \"Test+\": avg_func(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][score]),\n", + " \"Test-\": avg_func(df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"][score]),\n", + " \"MSnLib+\": avg_func(\n", + " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)\n", + " ),\n", + " \"MSnLib-\": avg_func(\n", + " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)\n", + " ),\n", + " \"CASMI16+\": avg_func(df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][score]),\n", + " \"CASMI16-\": avg_func(df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][score]),\n", + " \"CASMI22+\": avg_func(df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][score]),\n", + " \"CASMI22-\": avg_func(df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][score]),\n", + "}\n", + "cfm_id = {\n", + " \"model\": \"CFM-ID 4.4.7\",\n", + " \"Test+\": avg_func(\n", + " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"Test-\": avg_func(\n", + " df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"MSnLib+\": avg_func(\n", + " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"MSnLib-\": avg_func(\n", + " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M-H]-\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"CASMI16+\": avg_func(\n", + " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"CASMI16-\": avg_func(\n", + " df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"CASMI22+\": avg_func(\n", + " df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ]\n", + " ),\n", + " \"CASMI22-\": avg_func(\n", + " df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ]\n", + " ),\n", + "}\n", + "ice_res = {\n", + " \"model\": \"ICEBERG\",\n", + " \"Test+\": avg_func(\n", + " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"ice\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"MSnLib+\": avg_func(\n", + " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"ice\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"CASMI16+\": avg_func(\n", + " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"ice\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"CASMI22+\": avg_func(\n", + " df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"ice\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"CASMI22-\": avg_func(\n", + " df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][\n", + " score.replace(\"spectral\", \"ice\")\n", + " ].fillna(0.0)\n", + " ),\n", + "}\n", + "\n", + "summaryPos = pd.DataFrame([fiora_res, cfm_id, ice_res])\n", "print(\"Summary test sets\")\n", "summaryPos" ] @@ -16796,13 +17270,101 @@ "score = \"spectral_sqrt_cosine_wo_prec\"\n", "avg_func = np.median\n", "\n", - "fiora_res = {\"model\": \"Fiora\", \"Test+\": 
avg_func(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)), \"Test-\": avg_func(df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)), \"MSnLib+\": avg_func(df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)), \"MSnLib-\": avg_func(df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)), \"CASMI16+\": avg_func(df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)), \"CASMI16-\":avg_func(df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)), \"CASMI22+\": avg_func(df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)), \"CASMI22-\": avg_func(df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0))} \n", + "fiora_res = {\n", + " \"model\": \"Fiora\",\n", + " \"Test+\": avg_func(\n", + " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)\n", + " ),\n", + " \"Test-\": avg_func(\n", + " df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)\n", + " ),\n", + " \"MSnLib+\": avg_func(\n", + " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)\n", + " ),\n", + " \"MSnLib-\": avg_func(\n", + " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)\n", + " ),\n", + " \"CASMI16+\": avg_func(\n", + " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)\n", + " ),\n", + " \"CASMI16-\": avg_func(\n", + " df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)\n", + " ),\n", + " \"CASMI22+\": avg_func(\n", + " df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)\n", + " ),\n", + " \"CASMI22-\": avg_func(\n", + " df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)\n", + " ),\n", + "}\n", "cfm_id = {\n", - " \"model\": \"CFM-ID 4.4.7\", \"Test+\": avg_func(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"cfm\")].fillna(0.0)), \"Test-\": avg_func(df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"][score.replace(\"spectral\", \"cfm\")].fillna(0.0)), \"MSnLib+\": avg_func(df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"cfm\")].fillna(0.0)), \"MSnLib-\": avg_func(df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M-H]-\"][score.replace(\"spectral\", \"cfm\")].fillna(0.0)), \"CASMI16+\": avg_func(df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"cfm\")].fillna(0.0)), \"CASMI16-\": avg_func(df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][score.replace(\"spectral\", \"cfm\")].fillna(0.0)), \"CASMI22+\": avg_func(df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"cfm\")]), \"CASMI22-\": avg_func(df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][score.replace(\"spectral\", \"cfm\")])} \n", + " \"model\": \"CFM-ID 4.4.7\",\n", + " \"Test+\": avg_func(\n", + " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"Test-\": avg_func(\n", + " df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"MSnLib+\": avg_func(\n", + " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"MSnLib-\": avg_func(\n", + " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M-H]-\"][\n", + " 
score.replace(\"spectral\", \"cfm\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"CASMI16+\": avg_func(\n", + " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"CASMI16-\": avg_func(\n", + " df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"CASMI22+\": avg_func(\n", + " df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ]\n", + " ),\n", + " \"CASMI22-\": avg_func(\n", + " df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][\n", + " score.replace(\"spectral\", \"cfm\")\n", + " ]\n", + " ),\n", + "}\n", "ice_res = {\n", - " \"model\": \"ICEBERG\", \"Test+\": avg_func(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"ice\")].fillna(0.0)), \"MSnLib+\": avg_func(df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"ice\")].fillna(0.0)), \"CASMI16+\": avg_func(df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"ice\")].fillna(0.0)), \"CASMI22+\": avg_func(df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"ice\")].fillna(0.0))} \n", - "\n", - "summaryPos = pd.DataFrame( [fiora_res, cfm_id, ice_res])\n", + " \"model\": \"ICEBERG\",\n", + " \"Test+\": avg_func(\n", + " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"ice\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"MSnLib+\": avg_func(\n", + " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"ice\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"CASMI16+\": avg_func(\n", + " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"ice\")\n", + " ].fillna(0.0)\n", + " ),\n", + " \"CASMI22+\": avg_func(\n", + " df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][\n", + " score.replace(\"spectral\", \"ice\")\n", + " ].fillna(0.0)\n", + " ),\n", + "}\n", + "\n", + "summaryPos = pd.DataFrame([fiora_res, cfm_id, ice_res])\n", "print(\"Summary test sets - without precursor\")\n", "summaryPos" ] @@ -16831,9 +17393,9 @@ "metadata": {}, "outputs": [], "source": [ - "#CONCAT[[\"Dataset\", \"group_id\", \"ChallengeName\", \"Precursor_type\", \"spectral_sqrt_cosine\", \"ice_sqrt_cosine\", \"cfm_sqrt_cosine\"]].to_excel(f\"{home}/images/paper/T1.xlsx\")\n", - "#CONCAT[[\"Dataset\", \"group_id\", \"ChallengeName\", \"Precursor_type\", \"spectral_sqrt_cosine_wo_prec\", \"ice_sqrt_cosine_wo_prec\", \"cfm_sqrt_cosine_wo_prec\"]].to_excel(f\"{home}/images/paper/ST1.xlsx\")\n", - "#CONCAT[[\"Dataset\", \"group_id\", \"ChallengeName\", \"Precursor_type\", \"spectral_cosine\", \"ice_cosine\", \"cfm_cosine\"]].to_excel(f\"{home}/images/paper/ST2.xlsx\")" + "# CONCAT[[\"Dataset\", \"group_id\", \"ChallengeName\", \"Precursor_type\", \"spectral_sqrt_cosine\", \"ice_sqrt_cosine\", \"cfm_sqrt_cosine\"]].to_excel(f\"{home}/images/paper/T1.xlsx\")\n", + "# CONCAT[[\"Dataset\", \"group_id\", \"ChallengeName\", \"Precursor_type\", \"spectral_sqrt_cosine_wo_prec\", \"ice_sqrt_cosine_wo_prec\", \"cfm_sqrt_cosine_wo_prec\"]].to_excel(f\"{home}/images/paper/ST1.xlsx\")\n", + "# CONCAT[[\"Dataset\", \"group_id\", \"ChallengeName\", \"Precursor_type\", \"spectral_cosine\", \"ice_cosine\", \"cfm_cosine\"]].to_excel(f\"{home}/images/paper/ST2.xlsx\")" ] }, { @@ -16844,48 +17406,92 @@ "source": [ "### Stacked 
spectra plot\n", "\n", - "def stacked_spectrum(data, text_offset: float=0.0, spec_text_position_override=(0.02, 0.9), split_text_experimental=False, verbose: bool=False):\n", + "\n", + "def stacked_spectrum(\n", + " data,\n", + " text_offset: float = 0.0,\n", + " spec_text_position_override=(0.02, 0.9),\n", + " split_text_experimental=False,\n", + " verbose: bool = False,\n", + "):\n", " name_tags = [\"Fiora\", \"ICEBERG\", \"CFM-ID\"]\n", " peak_tags = [\"sim_peaks\", \"ice_peaks\", \"cfm_peaks\"]\n", - " scores = [(\"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\"), (\"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\"), (\"cfm_sqrt_cosine\", \"cfm_sqrt_cosine_wo_prec\")]\n", - " peak_colors = [\"#0080FF\", \"#FF3333\", \"#FFCC00\"]#sns.color_palette(\"YlOrBr\", 10)[3]]\n", + " scores = [\n", + " (\"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\"),\n", + " (\"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\"),\n", + " (\"cfm_sqrt_cosine\", \"cfm_sqrt_cosine_wo_prec\"),\n", + " ]\n", + " peak_colors = [\n", + " \"#0080FF\",\n", + " \"#FF3333\",\n", + " \"#FFCC00\",\n", + " ] # sns.color_palette(\"YlOrBr\", 10)[3]]\n", " spec_height, spec_width = 1.5, 8\n", - " \n", - " fig, axs = plt.subplots(len(peak_tags) + 2, 1, figsize=(spec_width, spec_height * (len(peak_tags) + 2)), sharex=True) # gridspec_kw={'width_ratios': [1, 3]}\n", + "\n", + " fig, axs = plt.subplots(\n", + " len(peak_tags) + 2,\n", + " 1,\n", + " figsize=(spec_width, spec_height * (len(peak_tags) + 2)),\n", + " sharex=True,\n", + " ) # gridspec_kw={'width_ratios': [1, 3]}\n", " img = data[\"Metabolite\"].draw(ax=axs[0])\n", - " textstr=f'Name: {data[\"Name\"] if \"Name\" in data.keys() else data[\"NAME\"] if \"NAME\" in data.keys() else data[\"ChallengeName\"]}\\nPrecursor type: {data[\"Precursor_type\"]}\\nCollision energy: {data[\"CE\"]:.1f} eV'\n", - " fig.text(.54 + text_offset, .8, textstr, #transform=axs[0].transAxes,\n", - " fontsize=8, horizontalalignment=\"left\", verticalalignment='top', bbox=dict(facecolor='white', alpha=0))\n", - " \n", - " plt.subplots_adjust(hspace=0.12) #(top=0.94, bottom=0.12, right=0.97, left=0.08)\n", + " textstr = f\"Name: {data['Name'] if 'Name' in data.keys() else data['NAME'] if 'NAME' in data.keys() else data['ChallengeName']}\\nPrecursor type: {data['Precursor_type']}\\nCollision energy: {data['CE']:.1f} eV\"\n", + " fig.text(\n", + " 0.54 + text_offset,\n", + " 0.8,\n", + " textstr, # transform=axs[0].transAxes,\n", + " fontsize=8,\n", + " horizontalalignment=\"left\",\n", + " verticalalignment=\"top\",\n", + " bbox=dict(facecolor=\"white\", alpha=0),\n", + " )\n", + "\n", + " plt.subplots_adjust(hspace=0.12) # (top=0.94, bottom=0.12, right=0.97, left=0.08)\n", " sv.plot_spectrum(data, highlight_matches=False, ppm_tolerance=200, ax=axs[1])\n", - " \n", - " axs[1].text(spec_text_position_override[0], spec_text_position_override[1], \"Experimental\\nspectrum\" if split_text_experimental else \"Experimental spectrum\", transform=axs[1].transAxes,\n", - " fontsize=11, verticalalignment='top')\n", + "\n", + " axs[1].text(\n", + " spec_text_position_override[0],\n", + " spec_text_position_override[1],\n", + " \"Experimental\\nspectrum\"\n", + " if split_text_experimental\n", + " else \"Experimental spectrum\",\n", + " transform=axs[1].transAxes,\n", + " fontsize=11,\n", + " verticalalignment=\"top\",\n", + " )\n", " axs[1].set_xlabel(\"\")\n", " for i, tag in enumerate(peak_tags):\n", - " try: \n", + " try:\n", " ax = axs[i + 2]\n", " sv.plot_spectrum({\"peaks\": data[tag]}, 
ax=ax, color=peak_colors[i])\n", - " #ax.legend(title=name_tags[i], loc=\"upper left\", labels=scores[i])\n", - " textstr = '\\n'.join((\n", - " f'$\\\\bf{{{name_tags[i]}}}$',\n", - " f'Cosine: {data[scores[i][0]]:.2f}',\n", - " f'w/o prec: {data[scores[i][1]]:.2f}'))\n", + " # ax.legend(title=name_tags[i], loc=\"upper left\", labels=scores[i])\n", + " textstr = \"\\n\".join(\n", + " (\n", + " f\"$\\\\bf{{{name_tags[i]}}}$\",\n", + " f\"Cosine: {data[scores[i][0]]:.2f}\",\n", + " f\"w/o prec: {data[scores[i][1]]:.2f}\",\n", + " )\n", + " )\n", "\n", " ax.set_xlabel(\"\")\n", - " ax.text(spec_text_position_override[0], spec_text_position_override[1], textstr, transform=ax.transAxes,\n", - " fontsize=11, verticalalignment='top')#,\n", - " #bbox=dict(boxstyle='square,pad=0.5', facecolor='white', alpha=0.5))\n", - " except: \n", + " ax.text(\n", + " spec_text_position_override[0],\n", + " spec_text_position_override[1],\n", + " textstr,\n", + " transform=ax.transAxes,\n", + " fontsize=11,\n", + " verticalalignment=\"top\",\n", + " ) # ,\n", + " # bbox=dict(boxstyle='square,pad=0.5', facecolor='white', alpha=0.5))\n", + " except:\n", " if verbose:\n", " print(f\"Could not plot spectrum {data[tag]}\")\n", " continue\n", " axs[-1].set_xlabel(\"m/z\")\n", - " \n", + "\n", " sv.set_default_peak_color(\"#212121\")\n", "\n", - " return fig, axs\n" + " return fig, axs" ] }, { @@ -16905,13 +17511,13 @@ } ], "source": [ - "#zzz = df_cas[df_cas[\"spectral_sqrt_cosine_wo_prec\"].isna()]\n", + "# zzz = df_cas[df_cas[\"spectral_sqrt_cosine_wo_prec\"].isna()]\n", "# Interesting spectra to look at with drastic changes from precursor removal-\n", "# 159, 93, 182, 124, 134, 205! (CASMI)\n", "# 6369! 13990 33153 62382! (Test split)\n", "\n", - "#img = stacked_spectrum(df_cas.loc[93], text_offset=0.05) # TODO wrong index, since reset_index was introduced above\n", - "#plt.show()\n", + "# img = stacked_spectrum(df_cas.loc[93], text_offset=0.05) # TODO wrong index, since reset_index was introduced above\n", + "# plt.show()\n", "\n", "img = stacked_spectrum(df_test.loc[62382])\n", "plt.show()" @@ -16923,9 +17529,13 @@ "metadata": {}, "outputs": [], "source": [ - "def get_average_spectra(df: pd.DataFrame, score: str=\"spectral_sqrt_cosine\", dif: float=0.05):\n", + "def get_average_spectra(\n", + " df: pd.DataFrame, score: str = \"spectral_sqrt_cosine\", dif: float = 0.05\n", + "):\n", " median_cos, median_cos_wo_prec = df[score].median(), df[score + \"_wo_prec\"].median()\n", - " filter_average = (abs(df[score] - median_cos) < dif) & (abs(df[score + \"_wo_prec\"] - median_cos_wo_prec) < dif)\n", + " filter_average = (abs(df[score] - median_cos) < dif) & (\n", + " abs(df[score + \"_wo_prec\"] - median_cos_wo_prec) < dif\n", + " )\n", "\n", " return df[filter_average]" ] @@ -16939,7 +17549,10 @@ "df_a = get_average_spectra(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"], dif=0.05)\n", "df_a = get_average_spectra(df_a, score=\"ice_sqrt_cosine\", dif=0.05)\n", "df_a = get_average_spectra(df_a, score=\"cfm_sqrt_cosine\", dif=0.05)\n", - "df_a = df_a[(df_a[\"spectral_sqrt_cosine_wo_prec\"] > df_a[\"ice_sqrt_cosine_wo_prec\"]) & (df_a[\"spectral_sqrt_cosine_wo_prec\"] > df_a[\"cfm_sqrt_cosine_wo_prec\"])]" + "df_a = df_a[\n", + " (df_a[\"spectral_sqrt_cosine_wo_prec\"] > df_a[\"ice_sqrt_cosine_wo_prec\"])\n", + " & (df_a[\"spectral_sqrt_cosine_wo_prec\"] > df_a[\"cfm_sqrt_cosine_wo_prec\"])\n", + "]" ] }, { @@ -16959,7 +17572,7 @@ } ], "source": [ - "#df_a = df_a[df_a[\"lib\"] == \"MSDIAL\"]\n", + "# df_a = 
df_a[df_a[\"lib\"] == \"MSDIAL\"]\n", "df_a.shape" ] }, @@ -17001,15 +17614,14 @@ "reset_matplotlib()\n", "for i, data in df_a.head(1).iterrows():\n", " print(data)\n", - " fig, axs = stacked_spectrum(df_a.loc[i], text_offset=0.02, split_text_experimental=True, verbose=True)\n", + " fig, axs = stacked_spectrum(\n", + " df_a.loc[i], text_offset=0.02, split_text_experimental=True, verbose=True\n", + " )\n", " mol = df_a.loc[i][\"Metabolite\"].MOL\n", " # fig.savefig(f\"{home}/images/paper/stacked_avg{i}.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/stacked_avg{i}.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/stacked_avg{i}.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", - " plt.show()\n", - " \n", - " \n", - " " + " plt.show()" ] }, { @@ -17027,7 +17639,7 @@ "# \"ICEBERG intensity\": df_a.iloc[0][\"ice_peaks\"][\"intensity\"],\n", "# \"CFM-ID m/z\": df_a.iloc[0][\"cfm_peaks\"][\"mz\"],\n", "# \"CFM-ID intensity\": df_a.iloc[0][\"cfm_peaks\"][\"intensity\"],\n", - "# } \n", + "# }\n", "# spec_df = pd.DataFrame.from_dict(new_df, orient=\"index\").transpose()\n", "# spec_df = spec_df.fillna(\"-\")" ] @@ -17038,7 +17650,7 @@ "metadata": {}, "outputs": [], "source": [ - "#spec_df.to_excel(f\"{home}/images/paper/F7.xlsx\")" + "# spec_df.to_excel(f\"{home}/images/paper/F7.xlsx\")" ] }, { @@ -17065,11 +17677,17 @@ "df_ex = df_msnlib_test[df_msnlib_test[\"SPECTYPE\"] == \"SAME_ENERGY\"]\n", "\n", "\n", - "high_score_30 = df_ex[(df_ex[\"CE\"] == 30) & (df_ex[\"spectral_sqrt_cosine\"] >= high_score_threshold)]\n", - "low_score_60 = df_ex[(df_ex[\"CE\"] == 60) & (df_ex[\"spectral_sqrt_cosine\"] <= low_score_threshold)]\n", + "high_score_30 = df_ex[\n", + " (df_ex[\"CE\"] == 30) & (df_ex[\"spectral_sqrt_cosine\"] >= high_score_threshold)\n", + "]\n", + "low_score_60 = df_ex[\n", + " (df_ex[\"CE\"] == 60) & (df_ex[\"spectral_sqrt_cosine\"] <= low_score_threshold)\n", + "]\n", "\n", - "common_group_ids = pd.merge(high_score_30[['group_id']], low_score_60[['group_id']], on='group_id')\n", - "matching_ids = common_group_ids['group_id'].unique()\n" + "common_group_ids = pd.merge(\n", + " high_score_30[[\"group_id\"]], low_score_60[[\"group_id\"]], on=\"group_id\"\n", + ")\n", + "matching_ids = common_group_ids[\"group_id\"].unique()" ] }, { @@ -17140,8 +17758,10 @@ "source": [ "reset_matplotlib()\n", "for i, data in examples.head(3).iterrows():\n", - " #print(data)\n", - " fig, axs = stacked_spectrum(examples.loc[i], text_offset=0.02, split_text_experimental=True, verbose=True)\n", + " # print(data)\n", + " fig, axs = stacked_spectrum(\n", + " examples.loc[i], text_offset=0.02, split_text_experimental=True, verbose=True\n", + " )\n", " ce = data[\"CE\"]\n", " mol = data[\"Metabolite\"].MOL\n", " ce = data[\"CE\"]\n", @@ -17198,16 +17818,15 @@ "source": [ "if has_m_plus:\n", " print(\"Conducting performance analysis for [M]+ and [M]- precursors\")\n", - " \n", + "\n", " df_m = df_msnlib_test[df_msnlib_test[\"Precursor_type\"].isin([\"[M]+\", \"[M]-\"])]\n", " df_m = df_m[df_m[\"SPECTYPE\"] != \"ALL_MSN_TO_PSEUDO_MS2\"]\n", - " df_not_m = df_msnlib_test[df_msnlib_test[\"Precursor_type\"].isin([\"[M+H]+\", \"[M-H]-\"])]\n", - " \n", - " \n", - " print(df_m.groupby(\"Precursor_type\")[\"spectral_sqrt_cosine\"].median())\n", - " print(df_m.groupby(\"Precursor_type\")[\"spectral_sqrt_cosine_wo_prec\"].median())\n", + " df_not_m = df_msnlib_test[\n", + " 
df_msnlib_test[\"Precursor_type\"].isin([\"[M+H]+\", \"[M-H]-\"])\n", + " ]\n", "\n", - " " + " print(df_m.groupby(\"Precursor_type\")[\"spectral_sqrt_cosine\"].median())\n", + " print(df_m.groupby(\"Precursor_type\")[\"spectral_sqrt_cosine_wo_prec\"].median())" ] }, { @@ -17228,7 +17847,7 @@ ], "source": [ "if has_m_plus:\n", - " print(df_m.groupby(\"Precursor_type\")[\"group_id\"].nunique())\n" + " print(df_m.groupby(\"Precursor_type\")[\"group_id\"].nunique())" ] }, { @@ -17281,9 +17900,14 @@ " for i, data2 in df_m.iterrows():\n", " other_metabolite = data2[\"Metabolite\"]\n", " if metabolite == other_metabolite:\n", - " if data[\"CE\"] == data2[\"CE\"] and ((\"+\" in data[\"Precursor_type\"]) == (\"+\" in data2[\"Precursor_type\"])):\n", - " print(\"cos:\", data[\"spectral_sqrt_cosine\"], data2[\"spectral_sqrt_cosine\"])\n", - "\n" + " if data[\"CE\"] == data2[\"CE\"] and (\n", + " (\"+\" in data[\"Precursor_type\"]) == (\"+\" in data2[\"Precursor_type\"])\n", + " ):\n", + " print(\n", + " \"cos:\",\n", + " data[\"spectral_sqrt_cosine\"],\n", + " data2[\"spectral_sqrt_cosine\"],\n", + " )" ] }, { @@ -17333,29 +17957,38 @@ "source": [ "if has_m_plus:\n", " data = df_m.iloc[0]\n", - " \n", + "\n", " for i, data in df_m.tail(2).iterrows():\n", " cosine = data[\"spectral_sqrt_cosine\"]\n", " cosine_wo = data[\"spectral_sqrt_cosine_wo_prec\"]\n", " name = data[\"NAME\"]\n", " prec = data[\"Precursor_type\"]\n", - " \n", + "\n", " print(f\"{prec} ({i}): cosine {cosine:0.2} / {cosine_wo:0.2}\")\n", " print(max(data[\"peaks\"][\"mz\"]))\n", " print(max(data[\"sim_peaks\"][\"mz\"]))\n", "\n", - " fig, axs = plt.subplots(1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", - " img = data[\"Metabolite\"].draw(ax= axs[0])\n", - "\n", + " fig, axs = plt.subplots(\n", + " 1,\n", + " 2,\n", + " figsize=(12.8, 4.2),\n", + " gridspec_kw={\"width_ratios\": [1, 3]},\n", + " sharey=False,\n", + " )\n", + " img = data[\"Metabolite\"].draw(ax=axs[0])\n", "\n", - " axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False) \n", + " axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " )\n", " axs[0].set_title(data[\"NAME\"])\n", "\n", - " ax = sv.plot_spectrum(data, {\"peaks\": data[\"sim_peaks\"]}, ax=axs[1], highlight_matches=False)\n", + " ax = sv.plot_spectrum(\n", + " data, {\"peaks\": data[\"sim_peaks\"]}, ax=axs[1], highlight_matches=False\n", + " )\n", " # fig.savefig(f\"{home}/images/paper/example_mirror_id{i}.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/example_mirror_id{i}.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/example_mirror_id{i}.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", - " plt.show()\n" + " plt.show()" ] }, { @@ -17409,9 +18042,9 @@ "# ### Uncomment to publish a new OS version\n", "# v = \"0.1.0\" # Enter version number\n", "# model.model_params[\"version_number\"] = v # First OS model to be published\n", - "# model.model_params[\"version\"] = f\"Fiora OS v{v}\" \n", + "# model.model_params[\"version\"] = f\"Fiora OS v{v}\"\n", "# model.model_params[\"training_library\"] = \"MSnLib\" # Add information\n", - "# model.model_params[\"comment\"] = \"This is the first open-source Fiora model released on GitHub trained on the MSnLib v1.0.\" \n", + "# model.model_params[\"comment\"] = \"This is the first 
open-source Fiora model released on GitHub trained on the MSnLib v1.0.\"\n", "# model.model_params[\"disclaimer\"] = \"No prediction software is perfect. This is an early open-source model. Use with caution.\"" ] }, @@ -17421,7 +18054,7 @@ "metadata": {}, "outputs": [], "source": [ - "# NEW_MODEL_PATH = f\"../models/fiora_OS_v{v}.pt\" \n", + "# NEW_MODEL_PATH = f\"../models/fiora_OS_v{v}.pt\"\n", "# model.save(NEW_MODEL_PATH)" ] }, @@ -17437,7 +18070,7 @@ "# print(zzz.shape)\n", "# for i, data in zzz.iterrows():\n", "# print(i)\n", - "# try: \n", + "# try:\n", "# stacked_spectrum(data)\n", "# plt.show()\n", "# except: continue" @@ -17526,7 +18159,7 @@ } ], "source": [ - "df_test.groupby(\"Precursor_type\").agg(num =(\"group_id\", lambda x: len(x.unique())))" + "df_test.groupby(\"Precursor_type\").agg(num=(\"group_id\", lambda x: len(x.unique())))" ] }, { @@ -17549,7 +18182,7 @@ "# df[\"precursor_mode\"] = df[\"Metabolite\"].apply(lambda x: x.metadata[\"precursor_mode\"])\n", "# ref_lib[\"theoretical_precursor_mz\"] = ref_lib[\"Metabolite\"].apply(lambda x: x.get_theoretical_precursor_mz(ion_type=x.metadata[\"precursor_mode\"]))\n", "# ref_lib[\"precursor_mode\"] = ref_lib[\"Metabolite\"].apply(lambda x: x.metadata[\"precursor_mode\"])\n", - " \n", + "\n", "# # Initialize k_scores column\n", "# df[\"k\"] = [0 for _ in range(len(df))]\n", "# df[\"k_wop\"] = [0 for _ in range(len(df))]\n", @@ -17558,18 +18191,18 @@ "# df[\"k_scores\"] = [[] for _ in range(len(df))]\n", "# df[\"k_scores_wop\"] = [[] for _ in range(len(df))]\n", "# df[\"k_scores_avg\"] = [[] for _ in range(len(df))]\n", - " \n", + "\n", "# for i, row in df.iterrows():\n", "# sim_peaks = row[\"sim_peaks\"]\n", "# sim_mz = row[\"theoretical_precursor_mz\"]\n", "# sim_mode = row[\"precursor_mode\"]\n", - " \n", + "\n", "# # Filter reference library based on matching mode and precursor mz\n", "# matching_refs = ref_lib[\n", "# (ref_lib[\"precursor_mode\"] == sim_mode) &\n", "# (abs(ref_lib[\"theoretical_precursor_mz\"] - sim_mz) < DEFAULT_DALTON)\n", "# ]\n", - " \n", + "\n", "# scores, scores_wop, scores_avg = [], [], []\n", "\n", "# for _, ref_row in matching_refs.iterrows():\n", @@ -17579,12 +18212,12 @@ "# score_wop = spectral_cosine(sim_peaks, ref_peaks, transform=np.sqrt, remove_mz=sim_mz)\n", "# scores_wop.append(score_wop)\n", "# scores_avg.append((score + score_wop) / 2.0)\n", - " \n", + "\n", "# # Keep top k scores\n", "# df.at[i, \"k_scores\"] = sorted(scores, reverse=True)[:k]\n", "# df.at[i, \"k_scores_wop\"] = sorted(scores_wop, reverse=True)[:k]\n", "# df.at[i, \"k_scores_avg\"] = sorted(scores_avg, reverse=True)[:k]\n", - " \n", + "\n", "# score = row[\"spectral_sqrt_cosine\"]\n", "# score_wop = row[\"spectral_sqrt_cosine_wo_prec\"]\n", "# score_avg = (score + score_wop) / 2.0\n", @@ -17594,7 +18227,7 @@ "# # Determine position k for each score\n", "# df.at[i, \"k\"] = sum(1 for s in df.at[i, \"k_scores\"] if score <= s) + 1\n", "# df.at[i, \"k_wop\"] = sum(1 for s in df.at[i, \"k_scores_wop\"] if score_wop <= s) + 1\n", - "# df.at[i, \"k_avg\"] = sum(1 for s in df.at[i, \"k_scores_avg\"] if score_avg <= s) + 1 \n" + "# df.at[i, \"k_avg\"] = sum(1 for s in df.at[i, \"k_scores_avg\"] if score_avg <= s) + 1\n" ] }, { @@ -17606,21 +18239,26 @@ "import requests\n", "import pubchempy\n", "\n", - "def retrieve_first_k_compounds_with_mass(mol_weight: float, tolerance: float=DEFAULT_DALTON, k: int=50):\n", + "\n", + "def retrieve_first_k_compounds_with_mass(\n", + " mol_weight: float, tolerance: float = DEFAULT_DALTON, k: 
int = 50\n", + "):\n", " # Construct the URL for the molecular weight range\n", " lower_bound = mol_weight - tolerance\n", " upper_bound = mol_weight + tolerance\n", " url = f\"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/exact_mass/range/{lower_bound}/{upper_bound}/cids/JSON\"\n", " # TODO test molecular_weight\n", - " \n", + "\n", " # Request compound IDs from PubChem\n", " response = requests.get(url)\n", " if response.ok:\n", " js = response.json()\n", " if \"IdentifierList\" in js and \"CID\" in js[\"IdentifierList\"]:\n", " # if too few entries -> double tolerance\n", - " if len(js[\"IdentifierList\"][\"CID\"]) < k: \n", - " return retrieve_first_k_compounds_with_mass(mol_weight, tolerance * 2, k)\n", + " if len(js[\"IdentifierList\"][\"CID\"]) < k:\n", + " return retrieve_first_k_compounds_with_mass(\n", + " mol_weight, tolerance * 2, k\n", + " )\n", " # Retrieve the first k compounds using PubChemPy\n", " compound_list = pubchempy.get_compounds(js[\"IdentifierList\"][\"CID\"][:k])\n", " return [c.canonical_smiles for c in compound_list]\n", @@ -17629,7 +18267,7 @@ " return retrieve_first_k_compounds_with_mass(mol_weight, tolerance * 2, k)\n", " else:\n", " print(f\"Error: {response.status_code} - {response.text}\")\n", - " return []\n" + " return []" ] }, { @@ -17649,49 +18287,93 @@ " return None\n", "\n", "\n", - "def get_top_k_scores(candidates: List[str], metabolite: Metabolite, exp_peaks: Dict, ce_steps: List[int] = [], k: int = 10) -> Dict[str, float]:\n", + "def get_top_k_scores(\n", + " candidates: List[str],\n", + " metabolite: Metabolite,\n", + " exp_peaks: Dict,\n", + " ce_steps: List[int] = [],\n", + " k: int = 10,\n", + ") -> Dict[str, float]:\n", " candidates_df = pd.DataFrame({\"SMILES\": [c for c in candidates if \".\" not in c]})\n", - " candidates_df[\"peaks\"] = [exp_peaks for _ in range(candidates_df.shape[0])] \n", - " \n", + " candidates_df[\"peaks\"] = [exp_peaks for _ in range(candidates_df.shape[0])]\n", + "\n", " # Perform Metabolite fragmentation workflow\n", - " candidates_df[\"Metabolite\"] = candidates_df[\"SMILES\"].apply(safe_metabolite_creation)\n", + " candidates_df[\"Metabolite\"] = candidates_df[\"SMILES\"].apply(\n", + " safe_metabolite_creation\n", + " )\n", " candidates_df.dropna(subset=[\"Metabolite\"], inplace=True)\n", " eq_metabolite_mask = candidates_df[\"Metabolite\"].apply(lambda m: m == metabolite)\n", " candidates_df = candidates_df[~eq_metabolite_mask]\n", " candidates_df[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", - " candidates_df[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", + " candidates_df[\"Metabolite\"].apply(\n", + " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", + " )\n", " metadata = metabolite.metadata\n", - " \n", + "\n", " if not ce_steps:\n", - " candidates_df[\"Metabolite\"].apply(lambda x: x.add_metadata(metadata, covariate_encoder, rt_encoder))\n", + " candidates_df[\"Metabolite\"].apply(\n", + " lambda x: x.add_metadata(metadata, covariate_encoder, rt_encoder)\n", + " )\n", " candidates_df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", " candidates_df = fiora.simulate_all(candidates_df, model)\n", " else:\n", " for i, ce in enumerate(ce_steps):\n", " metadata.update({\"collision_energy\": ce})\n", - " candidates_df[\"Metabolite\"].apply(lambda x: x.add_metadata(metadata, covariate_encoder, rt_encoder))\n", + " candidates_df[\"Metabolite\"].apply(\n", + " lambda x: x.add_metadata(metadata, 
covariate_encoder, rt_encoder)\n", + " )\n", " candidates_df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - " candidates_df = fiora.simulate_all(candidates_df, model, suffix=f'_{i + 1}')\n", + " candidates_df = fiora.simulate_all(candidates_df, model, suffix=f\"_{i + 1}\")\n", " if len(ce_steps) != 3:\n", - " raise NotImplementedError(\"Only three collision energy steps are implemented\")\n", - " candidates_df[\"merged_peaks\"] = candidates_df.apply(lambda x: merge_annotated_spectrum(merge_annotated_spectrum(x[\"sim_peaks_1\"], x[\"sim_peaks_2\"]), x[\"sim_peaks_3\"]) , axis=1)\n", - " candidates_df[\"spectral_sqrt_cosine\"] = candidates_df.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt), axis=1)\n", - " candidates_df[\"spectral_sqrt_cosine_wo_prec\"] = candidates_df.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt, remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(x[\"Metabolite\"].metadata[\"precursor_mode\"])), axis=1)\n", - " candidates_df[\"spectral_sqrt_cosine_avg\"] = (candidates_df[\"spectral_sqrt_cosine\"] + candidates_df[\"spectral_sqrt_cosine_wo_prec\"]) / 2.0\n", - " \n", - " score_tags = [\"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\", \"spectral_sqrt_cosine_avg\"]\n", + " raise NotImplementedError(\n", + " \"Only three collision energy steps are implemented\"\n", + " )\n", + " candidates_df[\"merged_peaks\"] = candidates_df.apply(\n", + " lambda x: merge_annotated_spectrum(\n", + " merge_annotated_spectrum(x[\"sim_peaks_1\"], x[\"sim_peaks_2\"]),\n", + " x[\"sim_peaks_3\"],\n", + " ),\n", + " axis=1,\n", + " )\n", + " candidates_df[\"spectral_sqrt_cosine\"] = candidates_df.apply(\n", + " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt),\n", + " axis=1,\n", + " )\n", + " candidates_df[\"spectral_sqrt_cosine_wo_prec\"] = candidates_df.apply(\n", + " lambda x: spectral_cosine(\n", + " x[\"peaks\"],\n", + " x[\"merged_peaks\"],\n", + " transform=np.sqrt,\n", + " remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(\n", + " x[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " ),\n", + " ),\n", + " axis=1,\n", + " )\n", + " candidates_df[\"spectral_sqrt_cosine_avg\"] = (\n", + " candidates_df[\"spectral_sqrt_cosine\"]\n", + " + candidates_df[\"spectral_sqrt_cosine_wo_prec\"]\n", + " ) / 2.0\n", + "\n", + " score_tags = [\n", + " \"spectral_sqrt_cosine\",\n", + " \"spectral_sqrt_cosine_wo_prec\",\n", + " \"spectral_sqrt_cosine_avg\",\n", + " ]\n", " scores = {\n", - " tag: candidates_df[tag].sort_values(ascending=False).head(k).values for tag in score_tags\n", + " tag: candidates_df[tag].sort_values(ascending=False).head(k).values\n", + " for tag in score_tags\n", " }\n", " scores[\"eq_c\"] = list(np.where(eq_metabolite_mask)[0])\n", - " \n", + "\n", " return scores\n", "\n", - "def get_k(row, scoring_func: str=\"spectral_sqrt_cosine\"):\n", + "\n", + "def get_k(row, scoring_func: str = \"spectral_sqrt_cosine\"):\n", " score = row[scoring_func]\n", " candidate_scores = row[\"candidate_scores\"][scoring_func]\n", - " epsilon = 0.0 # optional: add epsilon = 0.00001 # for Indistinguishable compounds / rounding deviation\n", - " return sum([(score - c_score) <= epsilon for c_score in candidate_scores]) + 1\n" + " epsilon = 0.0 # optional: add epsilon = 0.00001 # for Indistinguishable compounds / rounding deviation\n", + " return sum([(score - c_score) <= epsilon for c_score in candidate_scores]) + 1" ] }, { @@ -17736,7 +18418,9 @@ "import 
ast\n", "\n", "# # Load back the candidates into df_test\n", - "df_test[\"candidates\"] = pd.read_csv(f\"{home}/data/metabolites/benchmarking/df_test_candidates.csv\", index_col=0)[\"candidates\"].apply(ast.literal_eval)" + "df_test[\"candidates\"] = pd.read_csv(\n", + " f\"{home}/data/metabolites/benchmarking/df_test_candidates.csv\", index_col=0\n", + ")[\"candidates\"].apply(ast.literal_eval)" ] }, { @@ -17745,7 +18429,10 @@ "metadata": {}, "outputs": [], "source": [ - "df_test[\"candidate_scores\"] = df_test.apply(lambda x: get_top_k_scores(x[\"candidates\"], x[\"Metabolite\"], x[\"peaks\"], k=10) , axis=1)" + "df_test[\"candidate_scores\"] = df_test.apply(\n", + " lambda x: get_top_k_scores(x[\"candidates\"], x[\"Metabolite\"], x[\"peaks\"], k=10),\n", + " axis=1,\n", + ")" ] }, { @@ -17755,8 +18442,12 @@ "outputs": [], "source": [ "df_test[\"k\"] = df_test.apply(get_k, axis=1)\n", - "df_test[\"k_wo_prec\"] = df_test.apply(lambda x: get_k(x, scoring_func=\"spectral_sqrt_cosine_wo_prec\"), axis=1)\n", - "df_test[\"k_avg\"] = df_test.apply(lambda x: get_k(x, scoring_func=\"spectral_sqrt_cosine_avg\"), axis=1)" + "df_test[\"k_wo_prec\"] = df_test.apply(\n", + " lambda x: get_k(x, scoring_func=\"spectral_sqrt_cosine_wo_prec\"), axis=1\n", + ")\n", + "df_test[\"k_avg\"] = df_test.apply(\n", + " lambda x: get_k(x, scoring_func=\"spectral_sqrt_cosine_avg\"), axis=1\n", + ")" ] }, { @@ -17790,45 +18481,66 @@ "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", - "def plot_top_k_performance(df: pd.DataFrame, k_tags: list, labels: list=[], colors=[], max_rank: int = 11, ylim: (float, float)=(0,1), ratio=(8, 6), title=\"\"):\n", + "\n", + "def plot_top_k_performance(\n", + " df: pd.DataFrame,\n", + " k_tags: list,\n", + " labels: list = [],\n", + " colors=[],\n", + " max_rank: int = 11,\n", + " ylim: (float, float) = (0, 1),\n", + " ratio=(8, 6),\n", + " title=\"\",\n", + "):\n", " fig, ax = plt.subplots(figsize=ratio)\n", - " #fig.set_size_inches(5, 5)\n", - " #fig.tight_layout()\n", + " # fig.set_size_inches(5, 5)\n", + " # fig.tight_layout()\n", " # Loop over each k_tag and plot its cumulative fraction\n", " for i, k_tag in enumerate(k_tags):\n", " # Get counts for each k and ensure all ranks (1 to max_rank) are included\n", - " counts = df[k_tag].value_counts().reindex(range(1, max_rank + 1), fill_value=0).sort_index()\n", - " \n", + " counts = (\n", + " df[k_tag]\n", + " .value_counts()\n", + " .reindex(range(1, max_rank + 1), fill_value=0)\n", + " .sort_index()\n", + " )\n", + "\n", " # Compute cumulative counts and normalize to fractions\n", " cumulative_counts = counts.cumsum()\n", " total = cumulative_counts.iloc[-1]\n", " fractions = cumulative_counts / total\n", - " \n", + "\n", " # Create a DataFrame for plotting\n", - " cumulative_df = pd.DataFrame({\"Rank (k)\": range(1, max_rank), \"Fraction\": fractions.values[:-1]})\n", - " \n", + " cumulative_df = pd.DataFrame(\n", + " {\"Rank (k)\": range(1, max_rank), \"Fraction\": fractions.values[:-1]}\n", + " )\n", + "\n", " # Plot\n", " sns.pointplot(\n", " data=cumulative_df,\n", " x=\"Rank (k)\",\n", " y=\"Fraction\",\n", " label=labels[i] if len(labels) > 0 else k_tag,\n", - " linestyle=(0, (1,2.5)) if (i+1) % 3 == 0 else \"-\", # Alternate linestyles\n", - " markers=\"x\" if (i+1) % 3 == 0 else \"o\",\n", - " color=f\"C{i}\" if len(colors) == 0 else colors[i], # Use different colors for each tag\n", - " linewidth= 2.5,\n", + " linestyle=(0, (1, 2.5))\n", + " if (i + 1) % 3 == 0\n", + " else \"-\", # Alternate 
linestyles\n", + " markers=\"x\" if (i + 1) % 3 == 0 else \"o\",\n", + " color=f\"C{i}\"\n", + " if len(colors) == 0\n", + " else colors[i], # Use different colors for each tag\n", + " linewidth=2.5,\n", " )\n", - " \n", - " plt.rc('axes', labelsize=14)\n", - " plt.rc('legend', fontsize=14)\n", - " ax.tick_params(axis='both', which='major', labelsize=13)\n", - " \n", - " #ax.set_aspect(1, adjustable=\"box\")\n", + "\n", + " plt.rc(\"axes\", labelsize=14)\n", + " plt.rc(\"legend\", fontsize=14)\n", + " ax.tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", + "\n", + " # ax.set_aspect(1, adjustable=\"box\")\n", " # Final plot settings\n", " plt.xlabel(\"Rank (k)\")\n", " plt.ylabel(\"Recall\")\n", " plt.ylim(ylim)\n", - " plt.legend(title=title)#, labels=labels if len(labels) > 0 else k_tags)\n", + " plt.legend(title=title) # , labels=labels if len(labels) > 0 else k_tags)\n", " plt.grid(True)\n", "\n", " return fig, cumulative_df" @@ -17852,8 +18564,20 @@ ], "source": [ "from fiora.visualization.define_colors import set_light_theme\n", + "\n", "set_light_theme()\n", - "fig, fig_data = plot_top_k_performance(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"], [\"k\", \"k_wo_prec\", \"k_avg\"], labels=[\"Cosine similarity\", \"Cosine similarity w/o precursor\", \"Average cosine similarity\"], colors=[sns.color_palette(\"Paired\")[1], sns.color_palette(\"Paired\")[0], \"red\"], ratio=(7.2, 6), ylim=(0.35, 1))\n", + "fig, fig_data = plot_top_k_performance(\n", + " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"],\n", + " [\"k\", \"k_wo_prec\", \"k_avg\"],\n", + " labels=[\n", + " \"Cosine similarity\",\n", + " \"Cosine similarity w/o precursor\",\n", + " \"Average cosine similarity\",\n", + " ],\n", + " colors=[sns.color_palette(\"Paired\")[1], sns.color_palette(\"Paired\")[0], \"red\"],\n", + " ratio=(7.2, 6),\n", + " ylim=(0.35, 1),\n", + ")\n", "# fig.savefig(f\"{home}/images/paper/top_k_scoring_comparison.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/top_k_scoring_comparison.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/top_k_scoring_comparison.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -17869,24 +18593,38 @@ "source": [ "if False:\n", " # Step 1: Drop duplicate group_id entries\n", - " df_unique = df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"].drop_duplicates(subset=[\"group_id\"]).copy()\n", + " df_unique = (\n", + " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"]\n", + " .drop_duplicates(subset=[\"group_id\"])\n", + " .copy()\n", + " )\n", "\n", " # Step 2: Explode the candidates list into individual rows\n", " df_exploded = df_unique.explode(\"candidates\", ignore_index=True)\n", "\n", " # Step 3: Ensure Metabolite creation\n", - " df_exploded[\"Metabolite\"] = df_exploded[\"candidates\"].apply(safe_metabolite_creation)\n", - " assert df_exploded.groupby(\"group_id\").size().eq(49).all(), \"Some group_id is missing candidates\"\n", + " df_exploded[\"Metabolite\"] = df_exploded[\"candidates\"].apply(\n", + " safe_metabolite_creation\n", + " )\n", + " assert df_exploded.groupby(\"group_id\").size().eq(49).all(), (\n", + " \"Some group_id is missing candidates\"\n", + " )\n", "\n", " # Step 4: Add idx with c_1, c_2 suffixes\n", " df_exploded[\"idx\"] = (\n", - " \"spec\" + df_exploded[\"group_id\"].astype(int).astype(str) +\n", - " \"_c_\" + (df_exploded.groupby(\"group_id\").cumcount() + 1).astype(str)\n", + " 
\"spec\"\n", + " + df_exploded[\"group_id\"].astype(int).astype(str)\n", + " + \"_c_\"\n", + " + (df_exploded.groupby(\"group_id\").cumcount() + 1).astype(str)\n", " )\n", "\n", " # Step 5: Calculate formula and InChIKey, add empty strings for None values\n", - " df_exploded[\"formula\"] = df_exploded[\"Metabolite\"].apply(lambda x: x.Formula if x else \"\")\n", - " df_exploded[\"InChIKey\"] = df_exploded[\"Metabolite\"].apply(lambda x: x.InChIKey if x else \"\")\n", + " df_exploded[\"formula\"] = df_exploded[\"Metabolite\"].apply(\n", + " lambda x: x.Formula if x else \"\"\n", + " )\n", + " df_exploded[\"InChIKey\"] = df_exploded[\"Metabolite\"].apply(\n", + " lambda x: x.InChIKey if x else \"\"\n", + " )\n", "\n", " # Step 6: Map column names to ICEBERG format\n", " label_map = {\n", @@ -17901,28 +18639,44 @@ " # Step 7: Save in ICEBERG-compatible format\n", " df_exploded.rename(columns=label_map)[\n", " [\"dataset\", \"spec\", \"name\", \"ionization\", \"formula\", \"smiles\", \"inchikey\"]\n", - " ].to_csv(f\"{home}/data/metabolites/ms-pred/df_test_candidates.tsv\", index=False, sep=\"\\t\")\n", + " ].to_csv(\n", + " f\"{home}/data/metabolites/ms-pred/df_test_candidates.tsv\", index=False, sep=\"\\t\"\n", + " )\n", "\n", "if False:\n", " # Step 1: Drop duplicate ChallengeName entries\n", - " df_unique = df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"].drop_duplicates(subset=[\"ChallengeName\"]).copy()\n", + " df_unique = (\n", + " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"]\n", + " .drop_duplicates(subset=[\"ChallengeName\"])\n", + " .copy()\n", + " )\n", "\n", " # Step 2: Explode the candidates list into individual rows\n", " df_exploded = df_unique.explode(\"candidates\", ignore_index=True)\n", "\n", " # Step 3: Ensure Metabolite creation\n", - " df_exploded[\"Metabolite\"] = df_exploded[\"candidates\"].apply(safe_metabolite_creation)\n", - " assert df_exploded.groupby(\"ChallengeName\").size().eq(49).all(), \"Some ChallengeName is missing candidates\"\n", + " df_exploded[\"Metabolite\"] = df_exploded[\"candidates\"].apply(\n", + " safe_metabolite_creation\n", + " )\n", + " assert df_exploded.groupby(\"ChallengeName\").size().eq(49).all(), (\n", + " \"Some ChallengeName is missing candidates\"\n", + " )\n", "\n", " # Step 4: Add idx with c_1, c_2 suffixes\n", " df_exploded[\"idx\"] = (\n", - " \"spec\" + df_exploded[\"ChallengeName\"].astype(str) +\n", - " \"_c_\" + (df_exploded.groupby(\"ChallengeName\").cumcount() + 1).astype(str)\n", + " \"spec\"\n", + " + df_exploded[\"ChallengeName\"].astype(str)\n", + " + \"_c_\"\n", + " + (df_exploded.groupby(\"ChallengeName\").cumcount() + 1).astype(str)\n", " )\n", "\n", " # Step 5: Calculate formula and InChIKey, add empty strings for None values\n", - " df_exploded[\"formula\"] = df_exploded[\"Metabolite\"].apply(lambda x: x.Formula if x else \"\")\n", - " df_exploded[\"InChIKey\"] = df_exploded[\"Metabolite\"].apply(lambda x: x.InChIKey if x else \"\")\n", + " df_exploded[\"formula\"] = df_exploded[\"Metabolite\"].apply(\n", + " lambda x: x.Formula if x else \"\"\n", + " )\n", + " df_exploded[\"InChIKey\"] = df_exploded[\"Metabolite\"].apply(\n", + " lambda x: x.InChIKey if x else \"\"\n", + " )\n", "\n", " # Step 6: Map column names to ICEBERG format\n", " label_map = {\n", @@ -17937,8 +18691,9 @@ " # Step 7: Save in ICEBERG-compatible format\n", " df_exploded.rename(columns=label_map)[\n", " [\"dataset\", \"spec\", \"name\", \"ionization\", \"formula\", \"smiles\", \"inchikey\"]\n", - " 
].to_csv(f\"{home}/data/metabolites/ms-pred/df_cas_candidates.tsv\", index=False, sep=\"\\t\")\n", - "\n" + " ].to_csv(\n", + " f\"{home}/data/metabolites/ms-pred/df_cas_candidates.tsv\", index=False, sep=\"\\t\"\n", + " )" ] }, { @@ -17956,18 +18711,36 @@ " spec = row[\"peaks\"]\n", " group_id = int(row[\"group_id\"])\n", " eq_c = row[\"candidate_scores\"][\"eq_c\"]\n", - " \n", + "\n", " df_candidate_matches = df_ice[df_ice[\"group_id\"] == group_id]\n", - " df_candidate_matches = df_candidate_matches[~df_candidate_matches[\"c\"].isin([c+1 for c in eq_c])]\n", - " \n", - " ssc = list(df_candidate_matches[\"peaks\"].apply(lambda c_spec: spectral_cosine(c_spec, spec, transform=np.sqrt)))\n", - " sscwop = list(df_candidate_matches[\"peaks\"].apply(lambda c_spec: spectral_cosine(c_spec, spec, transform=np.sqrt, remove_mz=row[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=row[\"Metabolite\"].metadata[\"precursor_mode\"]))))\n", + " df_candidate_matches = df_candidate_matches[\n", + " ~df_candidate_matches[\"c\"].isin([c + 1 for c in eq_c])\n", + " ]\n", + "\n", + " ssc = list(\n", + " df_candidate_matches[\"peaks\"].apply(\n", + " lambda c_spec: spectral_cosine(c_spec, spec, transform=np.sqrt)\n", + " )\n", + " )\n", + " sscwop = list(\n", + " df_candidate_matches[\"peaks\"].apply(\n", + " lambda c_spec: spectral_cosine(\n", + " c_spec,\n", + " spec,\n", + " transform=np.sqrt,\n", + " remove_mz=row[\"Metabolite\"].get_theoretical_precursor_mz(\n", + " ion_type=row[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " ),\n", + " )\n", + " )\n", + " )\n", " sscavg = [(ssc[i] + sscwop[i]) / 2.0 for i in range(len(ssc))]\n", - " \n", + "\n", " row[\"candidate_scores\"][\"ice_sqrt_cosine\"] = sorted(ssc, reverse=True)[:10]\n", - " row[\"candidate_scores\"][\"ice_sqrt_cosine_wo_prec\"] = sorted(sscwop, reverse=True)[:10]\n", - " row[\"candidate_scores\"][\"ice_sqrt_cosine_avg\"] = sorted(sscavg, reverse=True)[:10]\n", - "\n" + " row[\"candidate_scores\"][\"ice_sqrt_cosine_wo_prec\"] = sorted(sscwop, reverse=True)[\n", + " :10\n", + " ]\n", + " row[\"candidate_scores\"][\"ice_sqrt_cosine_avg\"] = sorted(sscavg, reverse=True)[:10]" ] }, { @@ -17985,14 +18758,14 @@ "# spec = row[\"peaks\"]\n", "# challenge = row[\"ChallengeName\"]\n", "# eq_c = row[\"candidate_scores\"][\"eq_c\"]\n", - " \n", + "\n", "# df_candidate_matches = df_ice[df_ice[\"challenge\"] == challenge]\n", "# df_candidate_matches = df_candidate_matches[~df_candidate_matches[\"c\"].isin([c+1 for c in eq_c])]\n", - " \n", + "\n", "# ssc = list(df_candidate_matches[\"peaks\"].apply(lambda c_spec: spectral_cosine(c_spec, spec, transform=np.sqrt)))\n", "# sscwop = list(df_candidate_matches[\"peaks\"].apply(lambda c_spec: spectral_cosine(c_spec, spec, transform=np.sqrt, remove_mz=row[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=row[\"Metabolite\"].metadata[\"precursor_mode\"]))))\n", "# sscavg = [(ssc[i] + sscwop[i]) / 2.0 for i in range(len(ssc))]\n", - " \n", + "\n", "# row[\"candidate_scores\"][\"ice_sqrt_cosine\"] = sorted(ssc, reverse=True)[:10]\n", "# row[\"candidate_scores\"][\"ice_sqrt_cosine_wo_prec\"] = sorted(sscwop, reverse=True)[:10]\n", "# row[\"candidate_scores\"][\"ice_sqrt_cosine_avg\"] = sorted(sscavg, reverse=True)[:10]" @@ -18004,12 +18777,14 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "pos_mask = df_test[\"Precursor_type\"] == \"[M+H]+\"\n", - "df_test.loc[pos_mask, \"k_ice\"] = df_test.loc[pos_mask].apply(lambda x: get_k(x, scoring_func=\"ice_sqrt_cosine\"), axis=1)\n", - "df_test.loc[pos_mask, 
\"k_wo_prec_ice\"] = df_test.loc[pos_mask].apply(lambda x: get_k(x, scoring_func=\"ice_sqrt_cosine_wo_prec\"), axis=1)\n", - "# df_test.loc[pos_mask, \"k_avg_ice\"] = df_test.loc[pos_mask].apply(lambda x: get_k(x, scoring_func=\"ice_sqrt_cosine_avg\"), axis=1)\n", - "\n" + "df_test.loc[pos_mask, \"k_ice\"] = df_test.loc[pos_mask].apply(\n", + " lambda x: get_k(x, scoring_func=\"ice_sqrt_cosine\"), axis=1\n", + ")\n", + "df_test.loc[pos_mask, \"k_wo_prec_ice\"] = df_test.loc[pos_mask].apply(\n", + " lambda x: get_k(x, scoring_func=\"ice_sqrt_cosine_wo_prec\"), axis=1\n", + ")\n", + "# df_test.loc[pos_mask, \"k_avg_ice\"] = df_test.loc[pos_mask].apply(lambda x: get_k(x, scoring_func=\"ice_sqrt_cosine_avg\"), axis=1)" ] }, { @@ -18042,7 +18817,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "# pos_mask = df_cas[\"Precursor_type\"] == \"[M+H]+\"\n", "# df_cas.loc[pos_mask, \"k_ice\"] = df_cas.loc[pos_mask].apply(lambda x: get_k(x, scoring_func=\"ice_sqrt_cosine\"), axis=1)\n", "# df_cas.loc[pos_mask, \"k_wo_prec_ice\"] = df_cas.loc[pos_mask].apply(lambda x: get_k(x, scoring_func=\"ice_sqrt_cosine_wo_prec\"), axis=1)\n", @@ -18085,7 +18859,14 @@ } ], "source": [ - "fig, fig_data = plot_top_k_performance(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"], [\"k\", \"k_ice\"], labels=[\"Fiora\", \"ICEBERG\"], ylim=(0.2,1), colors=[lightblue_hex, lightpink_hex], ratio=(7.2,6))\n", + "fig, fig_data = plot_top_k_performance(\n", + " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"],\n", + " [\"k\", \"k_ice\"],\n", + " labels=[\"Fiora\", \"ICEBERG\"],\n", + " ylim=(0.2, 1),\n", + " colors=[lightblue_hex, lightpink_hex],\n", + " ratio=(7.2, 6),\n", + ")\n", "# fig.savefig(f\"{home}/images/paper/top_k.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/top_k.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/top_k.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -18110,7 +18891,14 @@ } ], "source": [ - "fig, fig_data = plot_top_k_performance(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"], [\"k_wo_prec\", \"k_wo_prec_ice\"], labels=[\"Fiora\", \"ICEBERG\"],ylim=(0.2,1), ratio=(7.2,6), colors=[lightblue_hex, lightpink_hex])\n", + "fig, fig_data = plot_top_k_performance(\n", + " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"],\n", + " [\"k_wo_prec\", \"k_wo_prec_ice\"],\n", + " labels=[\"Fiora\", \"ICEBERG\"],\n", + " ylim=(0.2, 1),\n", + " ratio=(7.2, 6),\n", + " colors=[lightblue_hex, lightpink_hex],\n", + ")\n", "# fig.savefig(f\"{home}/images/paper/top_k_wo_prec.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/top_k_wo_prec.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/top_k_wo_prec.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -18151,7 +18939,12 @@ "outputs": [], "source": [ "df_test[\"group_id\"] = df_test[\"group_id\"].astype(int)\n", - "df_test.drop_duplicates(\"group_id\", keep=\"first\")[[\"group_id\", \"SMILES\"]].to_csv(f\"{home}/data/metabolites/benchmarking/classyfire_input.csv\", header=None, sep=\" \", index=False)\n", + "df_test.drop_duplicates(\"group_id\", keep=\"first\")[[\"group_id\", \"SMILES\"]].to_csv(\n", + " f\"{home}/data/metabolites/benchmarking/classyfire_input.csv\",\n", + " header=None,\n", + " sep=\" \",\n", + " index=False,\n", + ")\n", "# Use classyfire via 
text interface to produce output csv: http://classyfire.wishartlab.com/#chemical-text-query" ] }, @@ -18161,9 +18954,15 @@ "metadata": {}, "outputs": [], "source": [ - "compound_classes = pd.read_csv(f\"{home}/data/metabolites/benchmarking/classyfire_output.csv\")\n", - "compound_classes[\"CompoundID\"] = pd.to_numeric(compound_classes[\"CompoundID\"], errors=\"coerce\", downcast=\"integer\")\n", - "compound_classes[['Category', 'Value']] = compound_classes['ClassifiedResults'].str.split(':', n=1, expand=True)" + "compound_classes = pd.read_csv(\n", + " f\"{home}/data/metabolites/benchmarking/classyfire_output.csv\"\n", + ")\n", + "compound_classes[\"CompoundID\"] = pd.to_numeric(\n", + " compound_classes[\"CompoundID\"], errors=\"coerce\", downcast=\"integer\"\n", + ")\n", + "compound_classes[[\"Category\", \"Value\"]] = compound_classes[\n", + " \"ClassifiedResults\"\n", + "].str.split(\":\", n=1, expand=True)" ] }, { @@ -18172,9 +18971,12 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "compound_classes[\"Value\"] = compound_classes[\"Value\"].fillna('')\n", - "compound_classes = compound_classes.groupby(['CompoundID', 'Category'])['Value'].agg(','.join).unstack()\n", + "compound_classes[\"Value\"] = compound_classes[\"Value\"].fillna(\"\")\n", + "compound_classes = (\n", + " compound_classes.groupby([\"CompoundID\", \"Category\"])[\"Value\"]\n", + " .agg(\",\".join)\n", + " .unstack()\n", + ")\n", "compound_classes.reset_index(inplace=True)\n", "compound_classes.columns.name = None" ] @@ -18317,7 +19119,9 @@ "outputs": [], "source": [ "num_classes = len(compound_classes[\"Superclass\"].unique())\n", - "superclass_map = dict(zip(compound_classes[\"Superclass\"].unique(), range(0, num_classes)))" + "superclass_map = dict(\n", + " zip(compound_classes[\"Superclass\"].unique(), range(0, num_classes))\n", + ")" ] }, { @@ -18326,10 +19130,7 @@ "metadata": {}, "outputs": [], "source": [ - "df_em = {\n", - " \"embedding\": [],\n", - " \"spectrum\": []\n", - "}" + "df_em = {\"embedding\": [], \"spectrum\": []}" ] }, { @@ -18352,9 +19153,15 @@ "for i, d in df_test.drop_duplicates(\"group_id\", keep=\"first\").iterrows():\n", " metabolite = d[\"Metabolite\"]\n", " group_id = d[\"group_id\"]\n", - " superclass = compound_classes[compound_classes[\"CompoundID\"] == group_id].iloc[0][\"Superclass\"]\n", - " subclass = compound_classes[compound_classes[\"CompoundID\"] == group_id].iloc[0][\"Subclass\"]\n", - " cclass = compound_classes[compound_classes[\"CompoundID\"] == group_id].iloc[0][\"Class\"]\n", + " superclass = compound_classes[compound_classes[\"CompoundID\"] == group_id].iloc[0][\n", + " \"Superclass\"\n", + " ]\n", + " subclass = compound_classes[compound_classes[\"CompoundID\"] == group_id].iloc[0][\n", + " \"Subclass\"\n", + " ]\n", + " cclass = compound_classes[compound_classes[\"CompoundID\"] == group_id].iloc[0][\n", + " \"Class\"\n", + " ]\n", "\n", " supermap[group_id] = superclass\n", " classmap[group_id] = cclass\n", @@ -18363,10 +19170,10 @@ " data = metabolite.as_geometric_data(with_labels=False).to(dev)\n", " batch = geom.data.Batch.from_data_list([data])\n", " embedding = model.get_graph_embedding(batch)\n", - " if d[\"Precursor_type\"] == \"[M+H]+\": \n", + " if d[\"Precursor_type\"] == \"[M+H]+\":\n", " df_em[\"embedding\"] += [embedding.flatten().cpu().detach().numpy()]\n", " df_em[\"spectrum\"] += [d[\"peaks\"]]\n", - " \n", + "\n", " test_group_id += [group_id]\n", " test_smiles += [d[\"SMILES\"]]\n", " test_embeddings += 
[embedding.flatten().cpu().detach().numpy()]\n", @@ -18378,9 +19185,7 @@ "\n", "df_test[\"Superclass\"] = df_test[\"group_id\"].map(supermap)\n", "df_test[\"Class\"] = df_test[\"group_id\"].map(classmap)\n", - "df_test[\"Subclass\"] = df_test[\"group_id\"].map(submap)\n", - "\n", - "\n" + "df_test[\"Subclass\"] = df_test[\"group_id\"].map(submap)" ] }, { @@ -18389,13 +19194,15 @@ "metadata": {}, "outputs": [], "source": [ - "Embedding_DF = pd.DataFrame({\n", - " \"group_id\": test_group_id,\n", - " \"SMILES\": test_smiles,\n", - " \"Superclass\": test_classes,\n", - " \"Compound Class\": test_cclasses,\n", - " \"Embedding\": test_embeddings\n", - "})" + "Embedding_DF = pd.DataFrame(\n", + " {\n", + " \"group_id\": test_group_id,\n", + " \"SMILES\": test_smiles,\n", + " \"Superclass\": test_classes,\n", + " \"Compound Class\": test_cclasses,\n", + " \"Embedding\": test_embeddings,\n", + " }\n", + ")" ] }, { @@ -18404,7 +19211,7 @@ "metadata": {}, "outputs": [], "source": [ - "#Embedding_DF.to_excel(f\"{home}/images/paper/F4a.xlsx\")" + "# Embedding_DF.to_excel(f\"{home}/images/paper/F4a.xlsx\")" ] }, { @@ -18413,7 +19220,7 @@ "metadata": {}, "outputs": [], "source": [ - "#Embedding_DF[Embedding_DF[\"Superclass\"] == \" Lipids and lipid-like molecules\"].to_excel(f\"{home}/images/paper/SF18and19.xlsx\")" + "# Embedding_DF[Embedding_DF[\"Superclass\"] == \" Lipids and lipid-like molecules\"].to_excel(f\"{home}/images/paper/SF18and19.xlsx\")" ] }, { @@ -18423,6 +19230,7 @@ "outputs": [], "source": [ "from fiora.MS.spectral_scores import spectral_cosine, cosine\n", + "\n", "em_sim = []\n", "spec_sim = []\n", "\n", @@ -18435,8 +19243,7 @@ " em2 = df_em[\"embedding\"][j]\n", " spec2 = df_em[\"spectrum\"][j]\n", " spec_sim += [spectral_cosine(spec, spec2, transform=np.sqrt)]\n", - " em_sim += [cosine(em, em2)]\n", - " " + " em_sim += [cosine(em, em2)]" ] }, { @@ -18456,7 +19263,7 @@ } ], "source": [ - "np.corrcoef(em_sim, spec_sim, dtype=float)[0,1]" + "np.corrcoef(em_sim, spec_sim, dtype=float)[0, 1]" ] }, { @@ -18488,7 +19295,7 @@ "metadata": {}, "outputs": [], "source": [ - "lipid_index = np.where(np.array(test_classes) == ' Lipids and lipid-like molecules')" + "lipid_index = np.where(np.array(test_classes) == \" Lipids and lipid-like molecules\")" ] }, { @@ -18568,8 +19375,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "\n", "# for i in range(5, 31, 5): # 7 or 8 neighbors or 20 to 25 neighbors\n", "# print(i)\n", "# reducer = umap.UMAP(n_neighbors=i, min_dist=0.1, metric=\"euclidean\", random_state=0, n_jobs=1) # 43, 50, 51, 52, 54! looks nice\n", @@ -18646,20 +19451,30 @@ "import umap\n", "\n", "reset_matplotlib()\n", - "reducer = umap.UMAP(n_neighbors=20, min_dist=0.1, random_state=0, n_jobs=1) # creates different umaps despite fixed seed. Reason probably pytorch.\n", + "reducer = umap.UMAP(\n", + " n_neighbors=20, min_dist=0.1, random_state=0, n_jobs=1\n", + ") # creates different umaps despite fixed seed. 
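# The em_sim/spec_sim loop above compares pairwise embedding cosines with
# pairwise spectral cosines and correlates them via np.corrcoef. A vectorised
# sketch of the embedding half, assuming `X` is an (n, d) array of graph
# embeddings; function and variable names are illustrative.
import numpy as np

def pairwise_cosine_upper(X: np.ndarray) -> np.ndarray:
    """Cosine similarity of every unordered pair of rows (upper triangle)."""
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True)  # unit-normalise rows
    iu = np.triu_indices(len(X), k=1)                  # i < j pairs only
    return (Xn @ Xn.T)[iu]

# r = np.corrcoef(pairwise_cosine_upper(X), spec_sim)[0, 1]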
Reason probably pytorch.\n", "reducer.fit(test_embeddings)\n", "e = reducer.transform(test_embeddings)\n", "\n", - "fig, ax = plt.subplots(1,1, figsize=(8,8))\n", - "scatter = sns.scatterplot(ax=ax,\n", + "fig, ax = plt.subplots(1, 1, figsize=(8, 8))\n", + "scatter = sns.scatterplot(\n", + " ax=ax,\n", " x=e[:, 0],\n", " y=e[:, 1],\n", - " hue=test_classes, edgecolor=\"white\", linewidth=0.48, s=40, palette=sns.color_palette(\"husl\", 13), style=test_classes, markers=[\"o\", (4,0, 45), \"D\", \"v\", \"p\", (4,1,0), \"X\"])# markers=[\"o\", \"*\", \"s\", \"^\", \"D\", \"h\", (4,1,0),\"v\", \"X\", \"P\", \"p\", \"<\", (4,0, 45)] s=30,, hue_order=np.unique(test_classes)[::-1])#, order=[' Organic acids and derivatives', ' Organoheterocyclic compounds', ' Benzenoids', ' Alkaloids and derivatives', ' Phenylpropanoids and polyketides', ])#, palette=sns.color_palette(\"colorblind\") + [\"black\", \"gray\", \"white\"])\n", - "legend = ax.legend(loc='lower left', bbox_to_anchor=(1, 0.5))\n", - "#plt.gca().set_aspect('equal', 'datalim')\n", - "#plt.ylim([-2.50,13])\n", - "#plt.xlim([-2.50,13])\n", - "ax.set_aspect('equal', 'datalim')\n", + " hue=test_classes,\n", + " edgecolor=\"white\",\n", + " linewidth=0.48,\n", + " s=40,\n", + " palette=sns.color_palette(\"husl\", 13),\n", + " style=test_classes,\n", + " markers=[\"o\", (4, 0, 45), \"D\", \"v\", \"p\", (4, 1, 0), \"X\"],\n", + ") # markers=[\"o\", \"*\", \"s\", \"^\", \"D\", \"h\", (4,1,0),\"v\", \"X\", \"P\", \"p\", \"<\", (4,0, 45)] s=30,, hue_order=np.unique(test_classes)[::-1])#, order=[' Organic acids and derivatives', ' Organoheterocyclic compounds', ' Benzenoids', ' Alkaloids and derivatives', ' Phenylpropanoids and polyketides', ])#, palette=sns.color_palette(\"colorblind\") + [\"black\", \"gray\", \"white\"])\n", + "legend = ax.legend(loc=\"lower left\", bbox_to_anchor=(1, 0.5))\n", + "# plt.gca().set_aspect('equal', 'datalim')\n", + "# plt.ylim([-2.50,13])\n", + "# plt.xlim([-2.50,13])\n", + "ax.set_aspect(\"equal\", \"datalim\")\n", "print(ax.get_xlim())\n", "ax.set_xlim((4.18, 14.15))\n", "ax.set_ylim(ax.get_xlim())\n", @@ -18668,7 +19483,7 @@ "\n", "# Print the default marker size\n", "print(\"Default marker size:\", default_marker_size)\n", - "#ax.set_ylim([4, 12])\n", + "# ax.set_ylim([4, 12])\n", "# fig.savefig(f\"{home}/images/paper/umap_alt2.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/umap_alt2.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "\n", @@ -18682,7 +19497,9 @@ "handles = legend.legend_handles\n", "\n", "# Get the facecolors from the legend handles (for scatter plot markers)\n", - "legend_colors = [handle.get_color() for handle in handles]# if isinstance(handle, matplotlib.patches.PathPatch)]\n", + "legend_colors = [\n", + " handle.get_color() for handle in handles\n", + "] # if isinstance(handle, matplotlib.patches.PathPatch)]\n", "\n", "\n", "plt.show()" @@ -18735,27 +19552,37 @@ "import umap\n", "\n", "reset_matplotlib()\n", - "reducer = umap.UMAP(n_neighbors=20, min_dist=0.1, random_state=0, n_jobs=1) # creates different umaps despite fixed seed. Reason probably pytorch.\n", + "reducer = umap.UMAP(\n", + " n_neighbors=20, min_dist=0.1, random_state=0, n_jobs=1\n", + ") # creates different umaps despite fixed seed. 
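# On the "different umaps despite fixed seed" comment: with random_state set
# (and n_jobs=1), umap-learn is deterministic for identical input, so the
# run-to-run drift plausibly enters upstream, where the PyTorch model
# recomputes test_embeddings. A quick, assumed check of which side varies:
import hashlib
import numpy as np

def embedding_digest(a) -> str:
    """Stable fingerprint of an embedding matrix; compare across runs."""
    return hashlib.sha1(np.ascontiguousarray(a, dtype=np.float32).tobytes()).hexdigest()[:12]

# embedding_digest(test_embeddings)  # changes between runs -> embeddings, not UMAP, drift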
Reason probably pytorch.\n", "reducer.fit(test_embeddings)\n", "e = reducer.transform(test_embeddings)\n", "\n", - "fig, ax = plt.subplots(1,1, figsize=(8,8))\n", - "scatter = sns.scatterplot(ax=ax,\n", + "fig, ax = plt.subplots(1, 1, figsize=(8, 8))\n", + "scatter = sns.scatterplot(\n", + " ax=ax,\n", " x=e[lipid_index[0], 0],\n", " y=e[lipid_index[0], 1],\n", - " hue=np.array(test_cclasses)[lipid_index[0]], edgecolor=\"white\", linewidth=0.48, s=40, palette=sns.color_palette(\"husl\", 6), style=np.array(test_cclasses)[lipid_index[0]], markers=[\"o\", (4,0, 45), \"D\", \"v\", \"p\", (4,1,0), \"X\"])# markers=[\"o\", \"*\", \"s\", \"^\", \"D\", \"h\", (4,1,0),\"v\", \"X\", \"P\", \"p\", \"<\", (4,0, 45)] s=30,, hue_order=np.unique(test_classes)[::-1])#, order=[' Organic acids and derivatives', ' Organoheterocyclic compounds', ' Benzenoids', ' Alkaloids and derivatives', ' Phenylpropanoids and polyketides', ])#, palette=sns.color_palette(\"colorblind\") + [\"black\", \"gray\", \"white\"])\n", - "legend = ax.legend(loc='lower left', bbox_to_anchor=(1, 0.5))\n", - "#plt.gca().set_aspect('equal', 'datalim')\n", - "#plt.ylim([-2.50,13])\n", - "#plt.xlim([-2.50,13])\n", - "ax.set_aspect('equal', 'datalim')\n", + " hue=np.array(test_cclasses)[lipid_index[0]],\n", + " edgecolor=\"white\",\n", + " linewidth=0.48,\n", + " s=40,\n", + " palette=sns.color_palette(\"husl\", 6),\n", + " style=np.array(test_cclasses)[lipid_index[0]],\n", + " markers=[\"o\", (4, 0, 45), \"D\", \"v\", \"p\", (4, 1, 0), \"X\"],\n", + ") # markers=[\"o\", \"*\", \"s\", \"^\", \"D\", \"h\", (4,1,0),\"v\", \"X\", \"P\", \"p\", \"<\", (4,0, 45)] s=30,, hue_order=np.unique(test_classes)[::-1])#, order=[' Organic acids and derivatives', ' Organoheterocyclic compounds', ' Benzenoids', ' Alkaloids and derivatives', ' Phenylpropanoids and polyketides', ])#, palette=sns.color_palette(\"colorblind\") + [\"black\", \"gray\", \"white\"])\n", + "legend = ax.legend(loc=\"lower left\", bbox_to_anchor=(1, 0.5))\n", + "# plt.gca().set_aspect('equal', 'datalim')\n", + "# plt.ylim([-2.50,13])\n", + "# plt.xlim([-2.50,13])\n", + "ax.set_aspect(\"equal\", \"datalim\")\n", "# Get the default marker size\n", "default_marker_size = ax.collections[0].get_sizes()[0]\n", "\n", "# Print the default marker size\n", "print(\"Default marker size:\", default_marker_size)\n", - "#ax.set_ylim([4, 12])\n", - "#fig.savefig(f\"{home}/images/paper/umap_lipids_global.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", + "# ax.set_ylim([4, 12])\n", + "# fig.savefig(f\"{home}/images/paper/umap_lipids_global.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "\n", "# Get the default line width (edges)\n", "default_line_width = ax.collections[0].get_linewidths()[0]\n", @@ -18767,7 +19594,9 @@ "handles = legend.legend_handles\n", "\n", "# Get the facecolors from the legend handles (for scatter plot markers)\n", - "legend_colors = [handle.get_color() for handle in handles]# if isinstance(handle, matplotlib.patches.PathPatch)]\n", + "legend_colors = [\n", + " handle.get_color() for handle in handles\n", + "] # if isinstance(handle, matplotlib.patches.PathPatch)]\n", "\n", "# fig.savefig(f\"{home}/images/paper/umap_lipids_global.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/umap_lipids_global.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -18803,8 +19632,10 @@ } ], "source": [ - "#TODO use as mask to plot lipids\n", - 
"df_test.drop_duplicates(\"group_id\", keep=\"first\")[\"Superclass\"] == \" Lipids and lipid-like molecules\"" + "# TODO use as mask to plot lipids\n", + "df_test.drop_duplicates(\"group_id\", keep=\"first\")[\n", + " \"Superclass\"\n", + "] == \" Lipids and lipid-like molecules\"" ] }, { @@ -18865,20 +19696,32 @@ ], "source": [ "from fiora.visualization.define_colors import adjust_box_widths\n", - "fig, ax = plt.subplots(1,1, figsize=(10, 6))\n", "\n", - "ax = sns.boxplot(ax=ax, data=df_test[df_test[\"Superclass\"] != \"nan\"], y=\"spectral_sqrt_cosine\", dodge=True, width=0.9, linewidth=1.5, hue=\"Superclass\", legend=False, palette=legend_colors, showfliers=False)\n", + "fig, ax = plt.subplots(1, 1, figsize=(10, 6))\n", + "\n", + "ax = sns.boxplot(\n", + " ax=ax,\n", + " data=df_test[df_test[\"Superclass\"] != \"nan\"],\n", + " y=\"spectral_sqrt_cosine\",\n", + " dodge=True,\n", + " width=0.9,\n", + " linewidth=1.5,\n", + " hue=\"Superclass\",\n", + " legend=False,\n", + " palette=legend_colors,\n", + " showfliers=False,\n", + ")\n", "ax.set_ylabel(\"cosine similarity\", fontsize=14)\n", - "ax.set_ylim([0,1])\n", - "plt.tick_params(axis='y', labelsize=14) \n", + "ax.set_ylim([0, 1])\n", + "plt.tick_params(axis=\"y\", labelsize=14)\n", "ax.set_xticks([])\n", "sns.despine(offset=10, trim=True)\n", - "ax.spines['bottom'].set_visible(False)\n", + "ax.spines[\"bottom\"].set_visible(False)\n", "\n", "adjust_box_widths(ax, 0.85)\n", "\n", - "#fig.savefig(f\"{home}/images/paper/cosine_by_class.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", - "#fig.savefig(f\"{home}/images/paper/cosine_by_class.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", + "# fig.savefig(f\"{home}/images/paper/cosine_by_class.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", + "# fig.savefig(f\"{home}/images/paper/cosine_by_class.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "plt.show()" ] }, @@ -18900,51 +19743,86 @@ ], "source": [ "from fiora.visualization.define_colors import adjust_box_widths, adjust_bar_widths\n", - "fig, axs = plt.subplots(2,1, figsize=(10, 9), sharex=True, height_ratios=[1,4])\n", - "plt.subplots_adjust(hspace=0.1)#top=0.94, bottom=0.12, right=0.97, left=0.08)\n", + "\n", + "fig, axs = plt.subplots(2, 1, figsize=(10, 9), sharex=True, height_ratios=[1, 4])\n", + "plt.subplots_adjust(hspace=0.1) # top=0.94, bottom=0.12, right=0.97, left=0.08)\n", "\n", "superclass_data = {}\n", "for i, superclass in enumerate(df_test[\"Superclass\"].unique()):\n", " if superclass != \"nan\":\n", - " superclass_data[i] = {\"Superclass\": superclass, \"num_spectra\": sum(df_test[\"Superclass\"] == superclass), \"num_compounds\": df_test[df_test[\"Superclass\"] == superclass][\"group_id\"].nunique()}\n", + " superclass_data[i] = {\n", + " \"Superclass\": superclass,\n", + " \"num_spectra\": sum(df_test[\"Superclass\"] == superclass),\n", + " \"num_compounds\": df_test[df_test[\"Superclass\"] == superclass][\n", + " \"group_id\"\n", + " ].nunique(),\n", + " }\n", "\n", "superclass_data = pd.DataFrame(superclass_data).transpose()\n", "\n", - "axs[0] = sns.barplot(ax=axs[0], data=superclass_data,y=\"num_compounds\", dodge=True, width=0.9, edgecolor=\"black\", linewidth=1.5, hue=\"Superclass\", legend=False, palette=legend_colors)\n", + "axs[0] = sns.barplot(\n", + " ax=axs[0],\n", + " data=superclass_data,\n", + " y=\"num_compounds\",\n", + " dodge=True,\n", + " width=0.9,\n", + " edgecolor=\"black\",\n", + " 
linewidth=1.5,\n", + " hue=\"Superclass\",\n", + " legend=False,\n", + " palette=legend_colors,\n", + ")\n", "for i, container in enumerate(axs[0].containers):\n", " axs[0].bar_label(axs[0].containers[i], fontsize=18)\n", " for bar in container:\n", " bar_label = f\"n={superclass_data.iloc[i]['num_spectra']}\"\n", - " axs[0].text(bar.get_x() + bar.get_width() / 2, -0.07, # Position at the base (y=0)\n", - " bar_label, ha='center', va='top', fontsize=14)\n", + " axs[0].text(\n", + " bar.get_x() + bar.get_width() / 2,\n", + " -0.07, # Position at the base (y=0)\n", + " bar_label,\n", + " ha=\"center\",\n", + " va=\"top\",\n", + " fontsize=14,\n", + " )\n", "\n", "axs[0].set_ylabel(\"\", fontsize=14)\n", - "axs[0].spines['bottom'].set_visible(False)\n", - "axs[0].tick_params(axis='y', labelsize=14) \n", + "axs[0].spines[\"bottom\"].set_visible(False)\n", + "axs[0].tick_params(axis=\"y\", labelsize=14)\n", "adjust_bar_widths(axs[0], 0.85)\n", "\n", "\n", - "axs[1] = sns.boxplot(ax=axs[1], data=df_test[df_test[\"Superclass\"] != \"nan\"], y=\"spectral_sqrt_cosine\", dodge=True, width=0.9, linewidth=1.5, hue=\"Superclass\", legend=False, palette=legend_colors, showfliers=False)\n", - "axs[1].set_ylabel(\"\")#\"cosine similarity\", fontsize=14)\n", - "axs[1].set_ylim([0,1])\n", - "plt.tick_params(axis='y', labelsize=14) \n", + "axs[1] = sns.boxplot(\n", + " ax=axs[1],\n", + " data=df_test[df_test[\"Superclass\"] != \"nan\"],\n", + " y=\"spectral_sqrt_cosine\",\n", + " dodge=True,\n", + " width=0.9,\n", + " linewidth=1.5,\n", + " hue=\"Superclass\",\n", + " legend=False,\n", + " palette=legend_colors,\n", + " showfliers=False,\n", + ")\n", + "axs[1].set_ylabel(\"\") # \"cosine similarity\", fontsize=14)\n", + "axs[1].set_ylim([0, 1])\n", + "plt.tick_params(axis=\"y\", labelsize=14)\n", "axs[1].set_xticks([])\n", "sns.despine(offset=10, trim=True)\n", - "axs[1].spines['bottom'].set_visible(False)\n", + "axs[1].spines[\"bottom\"].set_visible(False)\n", "adjust_box_widths(axs[1], 0.85)\n", "\n", - "plt.rc('axes', labelsize=18)\n", - "plt.rc('legend', fontsize=18)\n", + "plt.rc(\"axes\", labelsize=18)\n", + "plt.rc(\"legend\", fontsize=18)\n", "\n", - "axs[0].tick_params(axis='both', labelsize=18)\n", - "axs[1].tick_params(axis='both', labelsize=18)\n", + "axs[0].tick_params(axis=\"both\", labelsize=18)\n", + "axs[1].tick_params(axis=\"both\", labelsize=18)\n", "\n", "\n", "# fig.savefig(f\"{home}/images/paper/cosine_by_class_+hist.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/cosine_by_class_+hist.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/cosine_by_class_+hist.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -18962,7 +19840,9 @@ "metadata": {}, "outputs": [], "source": [ - "df_test[[\"group_id\", \"Superclass\", \"spectral_sqrt_cosine\"]].to_excel(f\"{home}/images/paper/F4c.xlsx\")" + "df_test[[\"group_id\", \"Superclass\", \"spectral_sqrt_cosine\"]].to_excel(\n", + " f\"{home}/images/paper/F4c.xlsx\"\n", + ")" ] }, { @@ -18982,14 +19862,13 @@ } ], "source": [ - "\n", "# print(df_test.groupby(\"Superclass\").group_id.unique().apply(len))\n", "# print(df_test.groupby(\"Superclass\").spectral_sqrt_cosine.median())\n", "# #df_test.groupby(\"Superclass\").spectral_sqrt_cosine.median()\n", "result = df_test.groupby(\"Superclass\").agg(\n", " num=(\"group_id\", lambda x: len(x.unique())),\n", " 
spec=(\"group_id\", lambda x: len(x)),\n", - " cos=(\"spectral_sqrt_cosine\", \"median\")\n", + " cos=(\"spectral_sqrt_cosine\", \"median\"),\n", ")\n", "print(result)" ] @@ -19012,19 +19891,31 @@ ], "source": [ "from fiora.visualization.define_colors import adjust_box_widths\n", - "fig, ax = plt.subplots(1,1, figsize=(10, 6))\n", "\n", - "ax = sns.boxplot(ax=ax, data=df_test[df_test[\"Class\"] != \"nan\"], y=\"spectral_sqrt_cosine\", dodge=True, width=0.9, linewidth=1.5, hue=\"Class\", legend=False, palette=legend_colors, showfliers=False)\n", + "fig, ax = plt.subplots(1, 1, figsize=(10, 6))\n", + "\n", + "ax = sns.boxplot(\n", + " ax=ax,\n", + " data=df_test[df_test[\"Class\"] != \"nan\"],\n", + " y=\"spectral_sqrt_cosine\",\n", + " dodge=True,\n", + " width=0.9,\n", + " linewidth=1.5,\n", + " hue=\"Class\",\n", + " legend=False,\n", + " palette=legend_colors,\n", + " showfliers=False,\n", + ")\n", "ax.set_ylabel(\"cosine similarity\", fontsize=14)\n", - "ax.set_ylim([0,1])\n", - "plt.tick_params(axis='y', labelsize=14) \n", + "ax.set_ylim([0, 1])\n", + "plt.tick_params(axis=\"y\", labelsize=14)\n", "ax.set_xticks([])\n", "sns.despine(offset=10, trim=True)\n", - "ax.spines['bottom'].set_visible(False)\n", + "ax.spines[\"bottom\"].set_visible(False)\n", "\n", "adjust_box_widths(fig, 0.85)\n", - "#fig.savefig(f\"{home}/images/paper/cosine_by_class.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", - "#fig.savefig(f\"{home}/images/paper/cosine_by_class.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", + "# fig.savefig(f\"{home}/images/paper/cosine_by_class.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", + "# fig.savefig(f\"{home}/images/paper/cosine_by_class.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "plt.show()" ] }, @@ -19052,7 +19943,9 @@ } ], "source": [ - "df_test[df_test[\"Superclass\"] == \" Lipids and lipid-like molecules\"].drop_duplicates(\"group_id\", keep=\"first\")[\"Class\"].value_counts()" + "df_test[df_test[\"Superclass\"] == \" Lipids and lipid-like molecules\"].drop_duplicates(\n", + " \"group_id\", keep=\"first\"\n", + ")[\"Class\"].value_counts()" ] }, { @@ -19082,11 +19975,17 @@ "submap = {}\n", "pal = sns.color_palette(\"viridis\", num_classes)\n", "\n", - "for i, d in df_test[df_test[\"Superclass\"] == \" Lipids and lipid-like molecules\"].drop_duplicates(\"group_id\", keep=\"first\").iterrows():\n", + "for i, d in (\n", + " df_test[df_test[\"Superclass\"] == \" Lipids and lipid-like molecules\"]\n", + " .drop_duplicates(\"group_id\", keep=\"first\")\n", + " .iterrows()\n", + "):\n", " metabolite = d[\"Metabolite\"]\n", " group_id = d[\"group_id\"]\n", "\n", - " cclass = compound_classes[compound_classes[\"CompoundID\"] == group_id].iloc[0][\"Class\"]\n", + " cclass = compound_classes[compound_classes[\"CompoundID\"] == group_id].iloc[0][\n", + " \"Class\"\n", + " ]\n", "\n", " data = metabolite.as_geometric_data(with_labels=False).to(dev)\n", " batch = geom.data.Batch.from_data_list([data])\n", @@ -19094,7 +19993,7 @@ " test_embeddings += [embedding.flatten().cpu().detach().numpy()]\n", " test_classes += [cclass]\n", "\n", - " colors += [pal[class_map[cclass]]]\n" + " colors += [pal[class_map[cclass]]]" ] }, { @@ -19118,29 +20017,39 @@ "import umap\n", "\n", "reset_matplotlib()\n", - "reducer = umap.UMAP(n_neighbors=8, min_dist=0.1, random_state=0, n_jobs=1) # creates different umaps despite fixed seed. 
Reason probably pytorch.\n", + "reducer = umap.UMAP(\n", + " n_neighbors=8, min_dist=0.1, random_state=0, n_jobs=1\n", + ") # creates different umaps despite fixed seed. Reason probably pytorch.\n", "reducer.fit(test_embeddings)\n", "e = reducer.transform(test_embeddings)\n", "\n", - "fig, ax = plt.subplots(1,1, figsize=(8,8))\n", - "scatter = sns.scatterplot(ax=ax,\n", + "fig, ax = plt.subplots(1, 1, figsize=(8, 8))\n", + "scatter = sns.scatterplot(\n", + " ax=ax,\n", " x=e[:, 0],\n", " y=e[:, 1],\n", - " hue=test_classes, edgecolor=\"white\", linewidth=0.48, s=40, palette=sns.color_palette(\"husl\", 6), style=test_classes, markers=[\"o\", (4,0, 45), \"D\", \"v\", \"p\", (4,1,0), \"X\"])# markers=[\"o\", \"*\", \"s\", \"^\", \"D\", \"h\", (4,1,0),\"v\", \"X\", \"P\", \"p\", \"<\", (4,0, 45)] s=30,, hue_order=np.unique(test_classes)[::-1])#, order=[' Organic acids and derivatives', ' Organoheterocyclic compounds', ' Benzenoids', ' Alkaloids and derivatives', ' Phenylpropanoids and polyketides', ])#, palette=sns.color_palette(\"colorblind\") + [\"black\", \"gray\", \"white\"])\n", - "legend = ax.legend(loc='lower left', bbox_to_anchor=(1, 0.5))\n", - "#plt.gca().set_aspect('equal', 'datalim')\n", - "#plt.ylim([-2.50,13])\n", - "#plt.xlim([-2.50,13])\n", - "#ax.set_aspect('equal', 'datalim')\n", - "#print(ax.get_xlim())\n", - "#ax.set_xlim((4.18, 14.15))\n", - "#ax.set_ylim(ax.get_xlim())\n", + " hue=test_classes,\n", + " edgecolor=\"white\",\n", + " linewidth=0.48,\n", + " s=40,\n", + " palette=sns.color_palette(\"husl\", 6),\n", + " style=test_classes,\n", + " markers=[\"o\", (4, 0, 45), \"D\", \"v\", \"p\", (4, 1, 0), \"X\"],\n", + ") # markers=[\"o\", \"*\", \"s\", \"^\", \"D\", \"h\", (4,1,0),\"v\", \"X\", \"P\", \"p\", \"<\", (4,0, 45)] s=30,, hue_order=np.unique(test_classes)[::-1])#, order=[' Organic acids and derivatives', ' Organoheterocyclic compounds', ' Benzenoids', ' Alkaloids and derivatives', ' Phenylpropanoids and polyketides', ])#, palette=sns.color_palette(\"colorblind\") + [\"black\", \"gray\", \"white\"])\n", + "legend = ax.legend(loc=\"lower left\", bbox_to_anchor=(1, 0.5))\n", + "# plt.gca().set_aspect('equal', 'datalim')\n", + "# plt.ylim([-2.50,13])\n", + "# plt.xlim([-2.50,13])\n", + "# ax.set_aspect('equal', 'datalim')\n", + "# print(ax.get_xlim())\n", + "# ax.set_xlim((4.18, 14.15))\n", + "# ax.set_ylim(ax.get_xlim())\n", "# Get the default marker size\n", "default_marker_size = ax.collections[0].get_sizes()[0]\n", "\n", "# Print the default marker size\n", "print(\"Default marker size:\", default_marker_size)\n", - "#ax.set_ylim([4, 12])\n", + "# ax.set_ylim([4, 12])\n", "# fig.savefig(f\"{home}/images/paper/umap_lipids_local.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/umap_lipids_local.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/umap_lipids_local.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -19155,7 +20064,9 @@ "handles = legend.legend_handles\n", "\n", "# Get the facecolors from the legend handles (for scatter plot markers)\n", - "legend_colors = [handle.get_color() for handle in handles]# if isinstance(handle, matplotlib.patches.PathPatch)]\n", + "legend_colors = [\n", + " handle.get_color() for handle in handles\n", + "] # if isinstance(handle, matplotlib.patches.PathPatch)]\n", "\n", "\n", "plt.show()" @@ -19195,29 +20106,61 @@ } ], "source": [ - "calc_tanimoto = True # This may tak a long 
time\n", - "if calc_tanimoto: \n", + "calc_tanimoto = True # This may tak a long time\n", + "if calc_tanimoto:\n", " print(\"Calculating Tanimoto scores. This may take a while\")\n", - " df_cas.loc[:,\"tanimoto\"] = np.nan\n", - " for i,d in df_cas.iterrows():\n", - " df_cas.at[i, \"tanimoto\"] = df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"].apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"])).max()\n", - " df_cas.at[i, \"tanimoto3\"] = df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"].apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"], finger=\"morgan3\")).max()\n", - " \n", - " df_cas22.loc[:,\"tanimoto\"] = np.nan\n", - " for i,d in df_cas22.iterrows():\n", - " df_cas22.at[i, \"tanimoto\"] = df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"].apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"])).max()\n", - " df_cas22.at[i, \"tanimoto3\"] = df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"].apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"], finger=\"morgan3\")).max()\n", - " \n", - " df_test.loc[:,\"tanimoto\"] = np.nan\n", - " for i,d in df_test.iterrows():\n", - " df_test.at[i, \"tanimoto\"] = df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"].apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"])).max()\n", - " df_test.at[i, \"tanimoto3\"] = df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"].apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"], finger=\"morgan3\")).max()\n", - " \n", - " df_msnlib_test.loc[:,\"tanimoto\"] = np.nan\n", - " df_msnlib_test.loc[:,\"tanimoto3\"] = np.nan\n", - " for i,d in df_msnlib_test.iterrows():\n", - " df_msnlib_test.at[i, \"tanimoto\"] = df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"].apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"])).max()\n", - " df_msnlib_test.at[i, \"tanimoto3\"] = df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"].apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"], finger=\"morgan3\")).max()\n" + " df_cas.loc[:, \"tanimoto\"] = np.nan\n", + " for i, d in df_cas.iterrows():\n", + " df_cas.at[i, \"tanimoto\"] = (\n", + " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", + " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"]))\n", + " .max()\n", + " )\n", + " df_cas.at[i, \"tanimoto3\"] = (\n", + " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", + " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"], finger=\"morgan3\"))\n", + " .max()\n", + " )\n", + "\n", + " df_cas22.loc[:, \"tanimoto\"] = np.nan\n", + " for i, d in df_cas22.iterrows():\n", + " df_cas22.at[i, \"tanimoto\"] = (\n", + " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", + " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"]))\n", + " .max()\n", + " )\n", + " df_cas22.at[i, \"tanimoto3\"] = (\n", + " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", + " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"], finger=\"morgan3\"))\n", + " .max()\n", + " )\n", + "\n", + " df_test.loc[:, \"tanimoto\"] = np.nan\n", + " for i, d in df_test.iterrows():\n", + " df_test.at[i, \"tanimoto\"] = (\n", + " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", + " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"]))\n", + " .max()\n", + " )\n", + " df_test.at[i, \"tanimoto3\"] = (\n", + " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", + " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"], 
finger=\"morgan3\"))\n", + " .max()\n", + " )\n", + "\n", + " df_msnlib_test.loc[:, \"tanimoto\"] = np.nan\n", + " df_msnlib_test.loc[:, \"tanimoto3\"] = np.nan\n", + " for i, d in df_msnlib_test.iterrows():\n", + " df_msnlib_test.at[i, \"tanimoto\"] = (\n", + " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", + " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"]))\n", + " .max()\n", + " )\n", + " df_msnlib_test.at[i, \"tanimoto3\"] = (\n", + " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", + " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"], finger=\"morgan3\"))\n", + " .max()\n", + " )" ] }, { @@ -19228,6 +20171,8 @@ "source": [ "df_test[\"group_id\"] = df_test[\"group_id\"].astype(int)\n", "new_value_offset = 100000\n", + "\n", + "\n", "# Function to assign unique metabolite identifiers\n", "def assign_metabolite_ids(df_ref: pd.DataFrame, metabolite_id_map):\n", " for i, data in df_ref.iterrows():\n", @@ -19244,6 +20189,7 @@ " df_ref.loc[i, \"group_id\"] = int(new_id)\n", " metabolite_id_map[int(new_id)] = metabolite\n", "\n", + "\n", "if calc_tanimoto:\n", " # Initialize the metabolite_id_map with metabolites from df_test\n", " metabolite_id_map = {}\n", @@ -19251,15 +20197,14 @@ " metabolite = df_test.loc[df_test[\"group_id\"] == group_id, \"Metabolite\"].iloc[0]\n", " metabolite_id_map[int(group_id)] = metabolite\n", "\n", - "\n", " # Apply the function to each dataframe\n", " assign_metabolite_ids(df_msnlib_test, metabolite_id_map)\n", " assign_metabolite_ids(df_cas, metabolite_id_map)\n", " assign_metabolite_ids(df_cas22, metabolite_id_map)\n", - " \n", + "\n", " df_msnlib_test[\"group_id\"] = df_msnlib_test[\"group_id\"].astype(int)\n", " df_cas[\"group_id\"] = df_cas[\"group_id\"].astype(int)\n", - " df_cas22[\"group_id\"] = df_cas22[\"group_id\"].astype(int)\n" + " df_cas22[\"group_id\"] = df_cas22[\"group_id\"].astype(int)" ] }, { @@ -19272,16 +20217,29 @@ "\n", "if calc_tanimoto:\n", " set_light_theme()\n", - " fig, ax = plt.subplots(1,1, figsize=(12, 6))\n", + " fig, ax = plt.subplots(1, 1, figsize=(12, 6))\n", " plt.subplots_adjust(top=0.94, bottom=0.12, right=0.97, left=0.08)\n", "\n", " C = pd.concat([df_test, df_msnlib_test, df_cas, df_cas22], ignore_index=True)\n", - " sns.pointplot(ax=ax, data=C, x=pd.cut(C[C[\"Precursor_type\"] == \"[M+H]+\"]['tanimoto3'], bins=[x/10.0 for x in list(range(2,11,1))]), y=\"spectral_sqrt_cosine\", palette=sns.color_palette(\"bright\"), capsize=.0, hue=\"Dataset\", dodge=0.25, estimator=\"median\") #multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", + " sns.pointplot(\n", + " ax=ax,\n", + " data=C,\n", + " x=pd.cut(\n", + " C[C[\"Precursor_type\"] == \"[M+H]+\"][\"tanimoto3\"],\n", + " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", + " ),\n", + " y=\"spectral_sqrt_cosine\",\n", + " palette=sns.color_palette(\"bright\"),\n", + " capsize=0.0,\n", + " hue=\"Dataset\",\n", + " dodge=0.25,\n", + " estimator=\"median\",\n", + " ) # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", " plt.ylim([0, 1])\n", " plt.legend(title=\"Dataset\", loc=\"lower left\")\n", " plt.ylabel(\"Cosine similarity\")\n", " plt.xlabel(\"Tanimoto similarity\")\n", - " \n", + "\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_withcasmi.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_withcasmi.svg\", 
format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_withcasmi.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -19295,17 +20253,62 @@ "outputs": [], "source": [ "if calc_tanimoto:\n", - " fig, ax = plt.subplots(1,1, figsize=(12, 6))\n", + " fig, ax = plt.subplots(1, 1, figsize=(12, 6))\n", " plt.subplots_adjust(top=0.94, bottom=0.12, right=0.97, left=0.08)\n", "\n", - " \n", - " sns.pointplot(ax=ax, data=C, x=pd.cut(C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")]['tanimoto3'], bins=[x/10.0 for x in list(range(2,11,1))]), y=\"spectral_sqrt_cosine\", capsize=.0, color=lightblue_hex, estimator=\"median\") #multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", - " sns.pointplot(ax=ax, data=C, x=pd.cut(C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")]['tanimoto3'], bins=[x/10.0 for x in list(range(2,11,1))]), y=\"cfm_sqrt_cosine\", capsize=.0, color=\"gray\", estimator=\"median\", linestyle=\"--\") #multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", - " sns.pointplot(ax=ax, data=C, x=pd.cut(C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")]['tanimoto3'], bins=[x/10.0 for x in list(range(2,11,1))]), y=\"ice_sqrt_cosine\", capsize=.0, color=lightpink_hex, estimator=\"median\") #multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", + " sns.pointplot(\n", + " ax=ax,\n", + " data=C,\n", + " x=pd.cut(\n", + " C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")][\n", + " \"tanimoto3\"\n", + " ],\n", + " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", + " ),\n", + " y=\"spectral_sqrt_cosine\",\n", + " capsize=0.0,\n", + " color=lightblue_hex,\n", + " estimator=\"median\",\n", + " ) # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", + " sns.pointplot(\n", + " ax=ax,\n", + " data=C,\n", + " x=pd.cut(\n", + " C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")][\n", + " \"tanimoto3\"\n", + " ],\n", + " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", + " ),\n", + " y=\"cfm_sqrt_cosine\",\n", + " capsize=0.0,\n", + " color=\"gray\",\n", + " estimator=\"median\",\n", + " linestyle=\"--\",\n", + " ) # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", + " sns.pointplot(\n", + " ax=ax,\n", + " data=C,\n", + " x=pd.cut(\n", + " C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")][\n", + " \"tanimoto3\"\n", + " ],\n", + " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", + " ),\n", + " y=\"ice_sqrt_cosine\",\n", + " capsize=0.0,\n", + " color=lightpink_hex,\n", + " estimator=\"median\",\n", + " ) # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", " plt.ylim([0.4, 1])\n", - " custom_lines = [plt.Line2D([0], [0], color=lightblue_hex, linestyle='-', marker=\"o\"), plt.Line2D([0], [0], color='gray', linestyle='--', marker=\"o\"), plt.Line2D([0], [0], color=lightpink_hex, linestyle='-', marker=\"o\")]\n", - " plt.legend(custom_lines, [\"Fiora\", \"CFM-ID\", \"ICEBERG\"], title=\"Software\", loc=\"upper left\")\n", - " \n", + " custom_lines = [\n", + " 
plt.Line2D([0], [0], color=lightblue_hex, linestyle=\"-\", marker=\"o\"),\n", + " plt.Line2D([0], [0], color=\"gray\", linestyle=\"--\", marker=\"o\"),\n", + " plt.Line2D([0], [0], color=lightpink_hex, linestyle=\"-\", marker=\"o\"),\n", + " ]\n", + " plt.legend(\n", + " custom_lines, [\"Fiora\", \"CFM-ID\", \"ICEBERG\"], title=\"Software\", loc=\"upper left\"\n", + " )\n", + "\n", " plt.ylabel(\"Cosine similarity\")\n", " plt.xlabel(\"Tanimoto similarity\")\n", " # fig.savefig(f\"{home}/images/paper/tanimoto.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -19324,34 +20327,83 @@ "datasets = [\"Test split\", \"MSnLib\", \"CASMI 16\", \"CASMI 22\"]\n", "score_func = \"spectral_sqrt_cosine\"\n", "if calc_tanimoto:\n", - " fig, axs = plt.subplots(2,1, figsize=(12, 9), height_ratios=[1,4], sharex=True)\n", + " fig, axs = plt.subplots(2, 1, figsize=(12, 9), height_ratios=[1, 4], sharex=True)\n", " plt.subplots_adjust(top=0.94, bottom=0.12, right=0.97, left=0.08, hspace=0.08)\n", "\n", - " binned_data = pd.cut(C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)][\"tanimoto3\"], bins=[x/10.0 for x in list(range(2,11,1))]).dropna()\n", - " grouped_data = C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)].groupby(binned_data)[\"group_id\"].nunique().reset_index(name=\"unique_group_ids\")\n", + " binned_data = pd.cut(\n", + " C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)][\"tanimoto3\"],\n", + " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", + " ).dropna()\n", + " grouped_data = (\n", + " C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)]\n", + " .groupby(binned_data)[\"group_id\"]\n", + " .nunique()\n", + " .reset_index(name=\"unique_group_ids\")\n", + " )\n", "\n", - " sns.barplot(ax=axs[0], x=binned_data.cat.categories, y=\"unique_group_ids\", data=grouped_data, linewidth=1.5, edgecolor=\"black\")\n", + " sns.barplot(\n", + " ax=axs[0],\n", + " x=binned_data.cat.categories,\n", + " y=\"unique_group_ids\",\n", + " data=grouped_data,\n", + " linewidth=1.5,\n", + " edgecolor=\"black\",\n", + " )\n", "\n", " for i in range(len(axs[0].containers)):\n", " axs[0].bar_label(axs[0].containers[i], fontsize=12)\n", " axs[0].set_ylabel(\"\", fontsize=14)\n", " axs[0].set_xlabel(\"\", fontsize=14)\n", - " axs[0].spines['top'].set_visible(False)\n", - " axs[0].spines['right'].set_visible(True)\n", - " axs[0].tick_params(axis='y', labelsize=12) \n", + " axs[0].spines[\"top\"].set_visible(False)\n", + " axs[0].spines[\"right\"].set_visible(True)\n", + " axs[0].tick_params(axis=\"y\", labelsize=12)\n", " adjust_bar_widths(axs[0], 0.5)\n", - " err = ('ci', 95)\n", - " sns.pointplot(ax=axs[1], data=C, x=binned_data, y=score_func, capsize=.0, color=lightblue_hex, estimator=\"median\", errorbar=err) #multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", - " sns.pointplot(ax=axs[1], data=C, x=binned_data, y=score_func.replace(\"spectral\", \"cfm\"), capsize=.0, color=\"gray\", estimator=\"median\", linestyle=\"--\", errorbar=err) #multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", - " sns.pointplot(ax=axs[1], data=C, x=binned_data, y=score_func.replace(\"spectral\", \"ice\"), capsize=.0, color=lightpink_hex, estimator=\"median\", errorbar=err) #multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", + " 
err = (\"ci\", 95)\n", + " sns.pointplot(\n", + " ax=axs[1],\n", + " data=C,\n", + " x=binned_data,\n", + " y=score_func,\n", + " capsize=0.0,\n", + " color=lightblue_hex,\n", + " estimator=\"median\",\n", + " errorbar=err,\n", + " ) # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", + " sns.pointplot(\n", + " ax=axs[1],\n", + " data=C,\n", + " x=binned_data,\n", + " y=score_func.replace(\"spectral\", \"cfm\"),\n", + " capsize=0.0,\n", + " color=\"gray\",\n", + " estimator=\"median\",\n", + " linestyle=\"--\",\n", + " errorbar=err,\n", + " ) # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", + " sns.pointplot(\n", + " ax=axs[1],\n", + " data=C,\n", + " x=binned_data,\n", + " y=score_func.replace(\"spectral\", \"ice\"),\n", + " capsize=0.0,\n", + " color=lightpink_hex,\n", + " estimator=\"median\",\n", + " errorbar=err,\n", + " ) # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", " plt.ylim([0.375, 1])\n", - " custom_lines = [plt.Line2D([0], [0], color=lightblue_hex, linestyle='-', marker=\"o\"), plt.Line2D([0], [0], color='gray', linestyle='--', marker=\"o\"), plt.Line2D([0], [0], color=lightpink_hex, linestyle='-', marker=\"o\")]\n", - " plt.legend(custom_lines, [\"Fiora\", \"CFM-ID\", \"ICEBERG\"], title=\"Software\", loc=\"upper left\")\n", - " \n", + " custom_lines = [\n", + " plt.Line2D([0], [0], color=lightblue_hex, linestyle=\"-\", marker=\"o\"),\n", + " plt.Line2D([0], [0], color=\"gray\", linestyle=\"--\", marker=\"o\"),\n", + " plt.Line2D([0], [0], color=lightpink_hex, linestyle=\"-\", marker=\"o\"),\n", + " ]\n", + " plt.legend(\n", + " custom_lines, [\"Fiora\", \"CFM-ID\", \"ICEBERG\"], title=\"Software\", loc=\"upper left\"\n", + " )\n", + "\n", " plt.ylabel(\"Cosine similarity\")\n", " plt.xlabel(\"Tanimoto similarity\")\n", - " axs[1].spines['top'].set_visible(True)\n", - " axs[1].tick_params(axis='y', labelsize=12) \n", + " axs[1].spines[\"top\"].set_visible(True)\n", + " axs[1].tick_params(axis=\"y\", labelsize=12)\n", "\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_+hist.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_+hist.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -19380,21 +20432,29 @@ "source": [ "# Filter the DataFrame based on Precursor_type and Dataset\n", "C = C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)]\n", - "score_cols = [\"spectral_sqrt_cosine_wo_prec\", \"ice_sqrt_cosine_wo_prec\", \"cfm_sqrt_cosine_wo_prec\"]#[\"spectral_sqrt_cosine_wo_prec\", \"ice_sqrt_cosine_wo_prec\", \"cfm_sqrt_cosine_wo_prec\"]\n", + "score_cols = [\n", + " \"spectral_sqrt_cosine_wo_prec\",\n", + " \"ice_sqrt_cosine_wo_prec\",\n", + " \"cfm_sqrt_cosine_wo_prec\",\n", + "] # [\"spectral_sqrt_cosine_wo_prec\", \"ice_sqrt_cosine_wo_prec\", \"cfm_sqrt_cosine_wo_prec\"]\n", "\n", "# Melt the DataFrame to explode the scores into a long format\n", - "melted_df = pd.melt(C, \n", - " id_vars=[\"tanimoto3\", \"Precursor_type\", \"Dataset\"], \n", - " value_vars=score_cols, \n", - " var_name=\"Software\", \n", - " value_name=\"Score\")\n", + "melted_df = pd.melt(\n", + " C,\n", + " id_vars=[\"tanimoto3\", \"Precursor_type\", \"Dataset\"],\n", + " value_vars=score_cols,\n", + " var_name=\"Software\",\n", + " value_name=\"Score\",\n", + ")\n", "\n", "# 
Map the software names to more descriptive names\n", - "melted_df[\"Software\"] = melted_df[\"Software\"].map({\n", - " \"spectral_sqrt_cosine_wo_prec\": \"Fiora\", #_wo_prec\n", - " \"ice_sqrt_cosine_wo_prec\": \"ICEBERG\",\n", - " \"cfm_sqrt_cosine_wo_prec\": \"CFM-ID\"\n", - "})\n", + "melted_df[\"Software\"] = melted_df[\"Software\"].map(\n", + " {\n", + " \"spectral_sqrt_cosine_wo_prec\": \"Fiora\", # _wo_prec\n", + " \"ice_sqrt_cosine_wo_prec\": \"ICEBERG\",\n", + " \"cfm_sqrt_cosine_wo_prec\": \"CFM-ID\",\n", + " }\n", + ")\n", "\n", "# Now melted_df contains the new format\n", "print(melted_df.head())" @@ -19427,66 +20487,94 @@ " fig, axs = plt.subplots(2, 1, figsize=(12, 9), height_ratios=[1, 4], sharex=True)\n", " plt.subplots_adjust(top=0.94, bottom=0.12, right=0.97, left=0.08, hspace=0.08)\n", "\n", - " binned_data = pd.cut(C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)][\"tanimoto3\"], bins=[x/10.0 for x in list(range(2,11,1))]).dropna()\n", - " grouped_data = C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)].groupby(binned_data)[\"group_id\"].nunique().reset_index(name=\"unique_group_ids\")\n", + " binned_data = pd.cut(\n", + " C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)][\"tanimoto3\"],\n", + " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", + " ).dropna()\n", + " grouped_data = (\n", + " C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)]\n", + " .groupby(binned_data)[\"group_id\"]\n", + " .nunique()\n", + " .reset_index(name=\"unique_group_ids\")\n", + " )\n", "\n", - " sns.barplot(ax=axs[0], x=binned_data.cat.categories, y=\"unique_group_ids\", data=grouped_data, linewidth=1.5, edgecolor=\"black\")\n", + " sns.barplot(\n", + " ax=axs[0],\n", + " x=binned_data.cat.categories,\n", + " y=\"unique_group_ids\",\n", + " data=grouped_data,\n", + " linewidth=1.5,\n", + " edgecolor=\"black\",\n", + " )\n", "\n", " # Prepare binned data from melted_df\n", - " melted_df = melted_df[melted_df[\"Dataset\"].isin(datasets)] # Use only relevant datasets\n", - " binned_data = pd.cut(melted_df[\"tanimoto3\"], bins=[x/10.0 for x in range(2, 11)]).dropna()\n", + " melted_df = melted_df[\n", + " melted_df[\"Dataset\"].isin(datasets)\n", + " ] # Use only relevant datasets\n", + " binned_data = pd.cut(\n", + " melted_df[\"tanimoto3\"], bins=[x / 10.0 for x in range(2, 11)]\n", + " ).dropna()\n", "\n", " # Add the binned data as a new column to melted_df\n", - " melted_df['Binned'] = binned_data\n", + " melted_df[\"Binned\"] = binned_data\n", "\n", " # Group for bar plot\n", - " grouped_data = melted_df.groupby('Binned')['Software'].nunique().reset_index(name=\"unique_group_ids\")\n", + " grouped_data = (\n", + " melted_df.groupby(\"Binned\")[\"Software\"]\n", + " .nunique()\n", + " .reset_index(name=\"unique_group_ids\")\n", + " )\n", "\n", " # Bar plot\n", - " #sns.barplot(ax=axs[0], x='Binned', y='unique_group_ids', data=grouped_data, linewidth=1.5, edgecolor=\"black\")\n", - " \n", - "\n", - "\n", - " \n", + " # sns.barplot(ax=axs[0], x='Binned', y='unique_group_ids', data=grouped_data, linewidth=1.5, edgecolor=\"black\")\n", "\n", " for i in range(len(axs[0].containers)):\n", " axs[0].bar_label(axs[0].containers[i], fontsize=12)\n", " axs[0].set_ylabel(\"\", fontsize=14)\n", " axs[0].set_xlabel(\"\", fontsize=14)\n", - " axs[0].spines['top'].set_visible(False)\n", - " axs[0].spines['right'].set_visible(True)\n", - " axs[0].tick_params(axis='y', labelsize=12) \n", + " 
axs[0].spines[\"top\"].set_visible(False)\n", + " axs[0].spines[\"right\"].set_visible(True)\n", + " axs[0].tick_params(axis=\"y\", labelsize=12)\n", " adjust_bar_widths(axs[0], 0.5)\n", "\n", " err = (\"pi\", 50) # Error bar type\n", "\n", " # Point plots with hue for Software\n", - " sns.pointplot(ax=axs[1], data=melted_df, x='Binned', y='Score', hue='Software', capsize=.0, linestyle=[\"-\", \"-\", \"--\"], palette=[lightblue_hex, lightpink_hex, \"gray\"], estimator=\"median\", errorbar=err, dodge=0.25)\n", + " sns.pointplot(\n", + " ax=axs[1],\n", + " data=melted_df,\n", + " x=\"Binned\",\n", + " y=\"Score\",\n", + " hue=\"Software\",\n", + " capsize=0.0,\n", + " linestyle=[\"-\", \"-\", \"--\"],\n", + " palette=[lightblue_hex, lightpink_hex, \"gray\"],\n", + " estimator=\"median\",\n", + " errorbar=err,\n", + " dodge=0.25,\n", + " )\n", " plt.ylim([0.2, 1])\n", "\n", - "\n", " plt.legend(loc=\"upper left\")\n", - " \n", + "\n", " plt.ylabel(\"Cosine similarity\")\n", " plt.xlabel(\"Tanimoto similarity\")\n", - " axs[1].spines['top'].set_visible(True)\n", - " axs[1].tick_params(axis='both', labelsize=13)\n", - " plt.rc('axes', labelsize=14)\n", - " plt.rc('legend', fontsize=14)\n", + " axs[1].spines[\"top\"].set_visible(True)\n", + " axs[1].tick_params(axis=\"both\", labelsize=13)\n", + " plt.rc(\"axes\", labelsize=14)\n", + " plt.rc(\"legend\", fontsize=14)\n", "\n", - " \n", " # Count the number of samples in each Binned group\n", - " sample_counts = melted_df.groupby('Binned').size()\n", + " sample_counts = melted_df.groupby(\"Binned\").size()\n", "\n", " # Add annotations for the sample counts\n", " for idx, count in enumerate(sample_counts):\n", - " axs[1].text(idx, 0.21, f'n={count}', ha='center', va='bottom', fontsize=13)\n", - " \n", - " \n", + " axs[1].text(idx, 0.21, f\"n={count}\", ha=\"center\", va=\"bottom\", fontsize=13)\n", + "\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_+hist_wop_iqr.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_+hist_wop_iqr.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_+hist_wop_iqr.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", - " plt.show()\n" + " plt.show()" ] }, { @@ -19505,32 +20593,46 @@ "outputs": [], "source": [ "if calc_tanimoto:\n", - "\n", - " fig, axs = plt.subplots(1,1, figsize=(12, 6))\n", - "\n", + " fig, axs = plt.subplots(1, 1, figsize=(12, 6))\n", "\n", " # Bin the data based on 'tanimoto3'\n", - " binned_data = pd.cut(C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)][\"tanimoto3\"], \n", - " bins=[x/10.0 for x in range(2, 11, 1)])\n", + " binned_data = pd.cut(\n", + " C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)][\"tanimoto3\"],\n", + " bins=[x / 10.0 for x in range(2, 11, 1)],\n", + " )\n", "\n", " # Group by both binned_data and 'Dataset', then count unique 'group_id'\n", - " grouped_data = C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)].groupby([binned_data, \"Dataset\"])[\"group_id\"].nunique().reset_index(name=\"unique_group_ids\")\n", + " grouped_data = (\n", + " C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)]\n", + " .groupby([binned_data, \"Dataset\"])[\"group_id\"]\n", + " .nunique()\n", + " .reset_index(name=\"unique_group_ids\")\n", + " )\n", "\n", - " sns.barplot(x=\"tanimoto3\", y=\"unique_group_ids\", hue=\"Dataset\", data=grouped_data, 
palette=\"tab10\", hue_order=datasets, edgecolor=\"black\", gap=0.15)\n", + " sns.barplot(\n", + " x=\"tanimoto3\",\n", + " y=\"unique_group_ids\",\n", + " hue=\"Dataset\",\n", + " data=grouped_data,\n", + " palette=\"tab10\",\n", + " hue_order=datasets,\n", + " edgecolor=\"black\",\n", + " gap=0.15,\n", + " )\n", "\n", " plt.xlabel(\"Tanimoto similarity\")\n", " plt.ylabel(\"Number of compounds\")\n", - " #plt.xticks(rotation=45)\n", - " plt.legend(title=\"Dataset\",loc='upper right')\n", + " # plt.xticks(rotation=45)\n", + " plt.legend(title=\"Dataset\", loc=\"upper right\")\n", " plt.tight_layout()\n", - " axs.tick_params(axis='both', labelsize=13)\n", - " plt.rc('axes', labelsize=14)\n", - " plt.rc('legend', fontsize=14)\n", + " axs.tick_params(axis=\"both\", labelsize=13)\n", + " plt.rc(\"axes\", labelsize=14)\n", + " plt.rc(\"legend\", fontsize=14)\n", "\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_distribution.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_distribution.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_distribution.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", - " plt.show()\n" + " plt.show()" ] }, { @@ -19549,10 +20651,32 @@ "outputs": [], "source": [ "if calc_tanimoto:\n", - " medians = C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")].groupby(pd.cut(C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")]['tanimoto3'], bins=[x/10.0 for x in list(range(2, 11, 1))]))['spectral_sqrt_cosine'].median()\n", + " medians = (\n", + " C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")]\n", + " .groupby(\n", + " pd.cut(\n", + " C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")][\n", + " \"tanimoto3\"\n", + " ],\n", + " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", + " )\n", + " )[\"spectral_sqrt_cosine\"]\n", + " .median()\n", + " )\n", " print(f\"Fiora: {medians.min() / medians.max()}\")\n", - " medians = C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")].groupby(pd.cut(C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")]['tanimoto3'], bins=[x/10.0 for x in list(range(2, 11, 1))]))['ice_sqrt_cosine'].median()\n", - " print(f\"ICEBERG: {medians.min() / medians.max()}\")\n" + " medians = (\n", + " C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")]\n", + " .groupby(\n", + " pd.cut(\n", + " C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")][\n", + " \"tanimoto3\"\n", + " ],\n", + " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", + " )\n", + " )[\"ice_sqrt_cosine\"]\n", + " .median()\n", + " )\n", + " print(f\"ICEBERG: {medians.min() / medians.max()}\")" ] }, { @@ -19578,12 +20702,20 @@ "\n", "C = pd.concat([df_test, df_msnlib_test, df_cas, df_cas22], ignore_index=True)\n", "\n", - "sns.histplot(C, ax=ax, x=\"spectral_sqrt_cosine\", hue=\"Dataset\", linewidth=1, multiple=\"stack\", edgecolor=\"black\")\n", + "sns.histplot(\n", + " C,\n", + " ax=ax,\n", + " x=\"spectral_sqrt_cosine\",\n", + " hue=\"Dataset\",\n", + " linewidth=1,\n", + " multiple=\"stack\",\n", + " edgecolor=\"black\",\n", + ")\n", "ax.set_xlabel(\"Spectral cosine similarity\")\n", "\n", - "plt.rc('axes', labelsize=14)\n", - "plt.rc('legend', fontsize=14)\n", - "ax.tick_params(axis='both', which='major', labelsize=13)\n", + 
"plt.rc(\"axes\", labelsize=14)\n", + "plt.rc(\"legend\", fontsize=14)\n", + "ax.tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", "# fig.savefig(f\"{home}/images/paper/histplot_cosine.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/histplot_cosine.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/histplot_cosine.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -19617,7 +20749,14 @@ } ], "source": [ - "fig, axs = plt.subplots(2, 1, figsize=(12, 12), gridspec_kw={'height_ratios': [1, 1]}, sharex=True, sharey=True)\n", + "fig, axs = plt.subplots(\n", + " 2,\n", + " 1,\n", + " figsize=(12, 12),\n", + " gridspec_kw={\"height_ratios\": [1, 1]},\n", + " sharex=True,\n", + " sharey=True,\n", + ")\n", "plt.subplots_adjust(top=0.94, bottom=0.12, right=0.97, left=0.08, hspace=0.1)\n", "\n", "C = pd.concat([df_test, df_msnlib_test, df_cas, df_cas22], ignore_index=True)\n", @@ -19627,15 +20766,19 @@ "sns.kdeplot(C, ax=axs[0], x=\"cfm_sqrt_cosine\", color=\"gray\", linewidth=2)\n", "axs[0].legend(title=\"Default\", labels=[\"Fiora\", \"ICEBERG\", \"CFM-ID\"], loc=\"upper left\")\n", "\n", - "sns.kdeplot(C, ax=axs[1], x=\"spectral_sqrt_cosine_wo_prec\", color=bluepink[0], linewidth=2)\n", + "sns.kdeplot(\n", + " C, ax=axs[1], x=\"spectral_sqrt_cosine_wo_prec\", color=bluepink[0], linewidth=2\n", + ")\n", "sns.kdeplot(C, ax=axs[1], x=\"ice_sqrt_cosine_wo_prec\", color=bluepink[1], linewidth=2)\n", "sns.kdeplot(C, ax=axs[1], x=\"cfm_sqrt_cosine_wo_prec\", color=\"gray\", linewidth=2)\n", - "axs[1].legend(title=\"Without precursor\", labels=[\"Fiora\", \"ICEBERG\", \"CFM-ID\"], loc=\"upper left\")\n", + "axs[1].legend(\n", + " title=\"Without precursor\", labels=[\"Fiora\", \"ICEBERG\", \"CFM-ID\"], loc=\"upper left\"\n", + ")\n", "axs[1].set_xlabel(\"Spectral cosine similarity\")\n", "\n", - "plt.rc('axes', labelsize=14)\n", - "plt.rc('legend', fontsize=14)\n", - "ax.tick_params(axis='both', which='major', labelsize=13)\n", + "plt.rc(\"axes\", labelsize=14)\n", + "plt.rc(\"legend\", fontsize=14)\n", + "ax.tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", "# fig.savefig(f\"{home}/images/paper/kdeplots_cosine.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/kdeplots_cosine.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/kdeplots_cosine.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -19648,7 +20791,7 @@ "metadata": {}, "outputs": [], "source": [ - "#C[[\"spectral_sqrt_cosine\", \"ice_sqrt_cosine\", \"cfm_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\", \"ice_sqrt_cosine_wo_prec\", \"cfm_sqrt_cosine_wo_prec\"]].to_excel(f\"{home}/images/paper/SF4.xlsx\")" + "# C[[\"spectral_sqrt_cosine\", \"ice_sqrt_cosine\", \"cfm_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\", \"ice_sqrt_cosine_wo_prec\", \"cfm_sqrt_cosine_wo_prec\"]].to_excel(f\"{home}/images/paper/SF4.xlsx\")" ] }, { @@ -19691,8 +20834,10 @@ " dif = Cx[score].fillna(0.0) - Cx[score_w].fillna(0.0)\n", " abs_avg_dif = np.mean(dif)\n", " rel_avg_dif = np.mean(dif / Cx[score].fillna(0.0))\n", - " \n", - " print(f\"Cosine loss from precursor removal (for {dataset} {p}): \\t{abs_avg_dif:.2f} ({rel_avg_dif*100:2.1f}%)\")" + "\n", + " print(\n", + " f\"Cosine loss from precursor removal (for {dataset} {p}): \\t{abs_avg_dif:.2f} 
({rel_avg_dif * 100:2.1f}%)\"\n", + " )" ] }, { @@ -19724,8 +20869,20 @@ ], "source": [ "print(\"Avg loss:\", np.mean(C[score].fillna(0.0) - C[score_w].fillna(0.0)))\n", - "print(\"Avg pos loss:\", np.mean(C[C[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0) - C[C[\"Precursor_type\"] == \"[M+H]+\"][score_w].fillna(0.0)))\n", - "print(\"Avg neg loss:\", np.mean(C[C[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0) - C[C[\"Precursor_type\"] == \"[M-H]-\"][score_w].fillna(0.0)))" + "print(\n", + " \"Avg pos loss:\",\n", + " np.mean(\n", + " C[C[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)\n", + " - C[C[\"Precursor_type\"] == \"[M+H]+\"][score_w].fillna(0.0)\n", + " ),\n", + ")\n", + "print(\n", + " \"Avg neg loss:\",\n", + " np.mean(\n", + " C[C[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)\n", + " - C[C[\"Precursor_type\"] == \"[M-H]-\"][score_w].fillna(0.0)\n", + " ),\n", + ")" ] }, { @@ -19754,49 +20911,93 @@ "source": [ "from fiora.visualization.define_colors import set_light_theme, tri_palette\n", "\n", - "fig, axs = plt.subplots(2, 1, figsize=(12, 14), gridspec_kw={'height_ratios': [1, 5]}, sharex=True)\n", + "fig, axs = plt.subplots(\n", + " 2, 1, figsize=(12, 14), gridspec_kw={\"height_ratios\": [1, 5]}, sharex=True\n", + ")\n", "plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", "set_light_theme()\n", "set_light_theme()\n", "\n", - "df_test_unique = df_test.dropna(subset=[\"RETENTIONTIME\"]).drop_duplicates(subset='SMILES', keep='first')\n", - "\n", - "plt.rc('legend', fontsize=20)\n", - "\n", - "#sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "sns.kdeplot(ax=axs[0], data=df_test_unique, x=\"RETENTIONTIME\", bw_adjust=0.25, palette=[\"gray\"], fill=True, hue=\"Dataset\", alpha=0.7)#, multiple=\"stack\") #hue=\"Precursor_type\", \n", - "sns.kdeplot(ax=axs[0], data=df_test_unique, x=\"RETENTIONTIME\", bw_adjust=0.25, color=\"#696969\", linewidth=1.7) #, multiple=\"stack\") #hue=\"Precursor_type\", \n", - "#axs[0].legend(title=\"Dataset\", loc=\"upper right\")\n", - "\n", - "\n", - "\n", - "axs[0].spines['top'].set_visible(False)\n", - "axs[0].spines['right'].set_visible(False)\n", - "\n", + "df_test_unique = df_test.dropna(subset=[\"RETENTIONTIME\"]).drop_duplicates(\n", + " subset=\"SMILES\", keep=\"first\"\n", + ")\n", "\n", - "sns.scatterplot(ax=axs[1], data=df_test_unique, x=\"RETENTIONTIME\", y=\"RT_pred\", color=\"gray\", marker=(4,1,0), s=60, edgecolor=\"None\")#, hue=\"library\", palette=tri_palette, style=\"library\", color=\"gray\")\n", - "axs[1].set_ylim([2,df_test_unique[\"RETENTIONTIME\"].max() + 0.5])\n", - "axs[1].set_xlim([2,df_test_unique[\"RETENTIONTIME\"].max() + 0.5])\n", + "plt.rc(\"legend\", fontsize=20)\n", + "\n", + "# sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "sns.kdeplot(\n", + " ax=axs[0],\n", + " data=df_test_unique,\n", + " x=\"RETENTIONTIME\",\n", + " bw_adjust=0.25,\n", + " palette=[\"gray\"],\n", + " fill=True,\n", + " hue=\"Dataset\",\n", + " alpha=0.7,\n", + ") # , multiple=\"stack\") #hue=\"Precursor_type\",\n", + "sns.kdeplot(\n", + " ax=axs[0],\n", + " data=df_test_unique,\n", + " x=\"RETENTIONTIME\",\n", + " bw_adjust=0.25,\n", + " color=\"#696969\",\n", + " linewidth=1.7,\n", + ") # , 
multiple=\"stack\") #hue=\"Precursor_type\",\n", + "# axs[0].legend(title=\"Dataset\", loc=\"upper right\")\n", + "\n", + "\n", + "axs[0].spines[\"top\"].set_visible(False)\n", + "axs[0].spines[\"right\"].set_visible(False)\n", + "\n", + "\n", + "sns.scatterplot(\n", + " ax=axs[1],\n", + " data=df_test_unique,\n", + " x=\"RETENTIONTIME\",\n", + " y=\"RT_pred\",\n", + " color=\"gray\",\n", + " marker=(4, 1, 0),\n", + " s=60,\n", + " edgecolor=\"None\",\n", + ") # , hue=\"library\", palette=tri_palette, style=\"library\", color=\"gray\")\n", + "axs[1].set_ylim([2, df_test_unique[\"RETENTIONTIME\"].max() + 0.5])\n", + "axs[1].set_xlim([2, df_test_unique[\"RETENTIONTIME\"].max() + 0.5])\n", "axs[1].set_ylabel(\"Predicted retention time (in minutes)\")\n", "axs[1].set_xlabel(\"Observed retention time (in minutes)\")\n", "line = [0, 100]\n", "sns.lineplot(ax=axs[1], x=line, y=line, color=\"black\")\n", - "sns.lineplot(ax=axs[1], x=line, y=[x + 30/60.0 for x in line], color=\"black\", linestyle='--')\n", - "sns.lineplot(ax=axs[1], x=line, y=[x - 30/60.0 for x in line], color=\"black\", linestyle='--')\n", + "sns.lineplot(\n", + " ax=axs[1], x=line, y=[x + 30 / 60.0 for x in line], color=\"black\", linestyle=\"--\"\n", + ")\n", + "sns.lineplot(\n", + " ax=axs[1], x=line, y=[x - 30 / 60.0 for x in line], color=\"black\", linestyle=\"--\"\n", + ")\n", "\n", "# sns.lineplot(ax=axs[1], x=line, y=[1.1*x for x in line], color=\"black\", linestyle='--')\n", "# sns.lineplot(ax=axs[1], x=line, y=[0.9*x for x in line], color=\"black\", linestyle='--')\n", "\n", - "#Text sizes\n", + "# Text sizes\n", "\n", - "axs[0].tick_params(axis='both', labelsize=16)\n", - "axs[1].tick_params(axis='both', labelsize=16)\n", - "plt.rc('axes', labelsize=20)\n", - "plt.rc('legend', fontsize=20)\n", + "axs[0].tick_params(axis=\"both\", labelsize=16)\n", + "axs[1].tick_params(axis=\"both\", labelsize=16)\n", + "plt.rc(\"axes\", labelsize=20)\n", + "plt.rc(\"legend\", fontsize=20)\n", "\n", "\n", - "fig.savefig(f\"{home}/images/paper/rt.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", - "fig.savefig(f\"{home}/images/paper/rt.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", + "fig.savefig(\n", + " f\"{home}/images/paper/rt.svg\",\n", + " format=\"svg\",\n", + " dpi=600,\n", + " bbox_inches=\"tight\",\n", + " pad_inches=0.1,\n", + ")\n", + "fig.savefig(\n", + " f\"{home}/images/paper/rt.pdf\",\n", + " format=\"pdf\",\n", + " dpi=600,\n", + " bbox_inches=\"tight\",\n", + " pad_inches=0.1,\n", + ")\n", "# fig.savefig(f\"{home}/images/paper/rt.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "\n", "\n", @@ -19828,10 +21029,17 @@ } ], "source": [ - "\n", "print(\"Pearson Corr Coef:\")\n", - "print(\"GNN PC\", np.corrcoef(df_test_unique[\"RETENTIONTIME\"], df_test_unique[\"RT_pred\"].dropna(), dtype=float)[0,1])\n", - "print(\"GNN R2\", r2_score(df_test_unique[\"RETENTIONTIME\"], df_test_unique[\"RT_pred\"].dropna()))\n" + "print(\n", + " \"GNN PC\",\n", + " np.corrcoef(\n", + " df_test_unique[\"RETENTIONTIME\"], df_test_unique[\"RT_pred\"].dropna(), dtype=float\n", + " )[0, 1],\n", + ")\n", + "print(\n", + " \"GNN R2\",\n", + " r2_score(df_test_unique[\"RETENTIONTIME\"], df_test_unique[\"RT_pred\"].dropna()),\n", + ")" ] }, { @@ -19840,8 +21048,12 @@ "metadata": {}, "outputs": [], "source": [ - "df_train_rt = df_train[~df_train[\"RETENTIONTIME\"].isna()].drop_duplicates(subset='group_id', keep='first')\n", - "df_test_rt = 
df_test[~df_test[\"RETENTIONTIME\"].isna()].drop_duplicates(subset='group_id', keep='first')" + "df_train_rt = df_train[~df_train[\"RETENTIONTIME\"].isna()].drop_duplicates(\n", + " subset=\"group_id\", keep=\"first\"\n", + ")\n", + "df_test_rt = df_test[~df_test[\"RETENTIONTIME\"].isna()].drop_duplicates(\n", + " subset=\"group_id\", keep=\"first\"\n", + ")" ] }, { @@ -19870,7 +21082,7 @@ "metadata": {}, "outputs": [], "source": [ - "with open(f'{home}/data/metabolites/rt/train.logp') as infile:\n", + "with open(f\"{home}/data/metabolites/rt/train.logp\") as infile:\n", " lines = infile.readlines()\n", "logps = []\n", "for line in lines[11:]:\n", @@ -19885,7 +21097,7 @@ "metadata": {}, "outputs": [], "source": [ - "with open(f'{home}/data/metabolites/rt/test.logp') as infile:\n", + "with open(f\"{home}/data/metabolites/rt/test.logp\") as infile:\n", " lines = infile.readlines()\n", "logps = []\n", "for line in lines[11:]:\n", @@ -19946,16 +21158,34 @@ } ], "source": [ - "\n", - "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(df_train_rt[\"logp\"], df_train_rt[\"RETENTIONTIME\"])\n", + "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(\n", + " df_train_rt[\"logp\"], df_train_rt[\"RETENTIONTIME\"]\n", + ")\n", "print(\"TEST SPLIT:\\n\")\n", "print(\"Pearson Corr Coef:\")\n", - "print(\"GNN\", np.corrcoef(df_test_rt[\"RETENTIONTIME\"], df_test_rt[\"RT_pred\"].dropna(), dtype=float)[0,1])\n", - "print(\"LR \", np.corrcoef(df_test_rt[\"RETENTIONTIME\"], intercept + slope * df_test_rt[\"logp\"].dropna(), dtype=float)[0,1])\n", + "print(\n", + " \"GNN\",\n", + " np.corrcoef(\n", + " df_test_rt[\"RETENTIONTIME\"], df_test_rt[\"RT_pred\"].dropna(), dtype=float\n", + " )[0, 1],\n", + ")\n", + "print(\n", + " \"LR \",\n", + " np.corrcoef(\n", + " df_test_rt[\"RETENTIONTIME\"],\n", + " intercept + slope * df_test_rt[\"logp\"].dropna(),\n", + " dtype=float,\n", + " )[0, 1],\n", + ")\n", "\n", "print(\"R2\")\n", "print(\"GNN\", r2_score(df_test_rt[\"RETENTIONTIME\"], df_test_rt[\"RT_pred\"].dropna()))\n", - "print(\"LR \", r2_score(df_test_rt[\"RETENTIONTIME\"], intercept + slope * df_test_rt[\"logp\"].dropna()))" + "print(\n", + " \"LR \",\n", + " r2_score(\n", + " df_test_rt[\"RETENTIONTIME\"], intercept + slope * df_test_rt[\"logp\"].dropna()\n", + " ),\n", + ")" ] }, { @@ -19991,36 +21221,87 @@ ], "source": [ "# TODO NEXT UP!!\n", - "fig, axs = plt.subplots(2, 1, figsize=(12, 14), gridspec_kw={'height_ratios': [1, 5]}, sharex=True)\n", + "fig, axs = plt.subplots(\n", + " 2, 1, figsize=(12, 14), gridspec_kw={\"height_ratios\": [1, 5]}, sharex=True\n", + ")\n", "plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", "\n", - "df_test_unique = df_test.dropna(subset=[\"CCS\"]).drop_duplicates(subset='SMILES', keep='first')\n", - "df_cas22_unique = df_cas22.dropna(subset=[\"CCS\"]).drop_duplicates(subset='SMILES', keep='first') # Note that with more lenient filters, CCS values might be annotated for CASMI 22\n", - "\n", - "CCS = pd.concat([df_test_unique[[\"CCS\", \"CCS_pred\", \"Dataset\"]], df_cas[[\"CCS\", \"CCS_pred\", \"Dataset\"]], df_cas22_unique[[\"CCS\", \"CCS_pred\", \"Dataset\"]] ], ignore_index=True) #\n", - "\n", - "\n", - "#sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "sns.kdeplot(ax=axs[0], data=CCS[CCS[\"Dataset\"] != \"CASMI 22\"], x=\"CCS\", bw_adjust=0.35, color=\"black\", 
multiple=\"stack\", hue=\"Dataset\", palette=[\"#696969\"] + [\"white\"], linewidth=1.7, fill=False)#, edgecolor=\"lightgray\") #hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "sns.kdeplot(ax=axs[0], data=CCS[CCS[\"Dataset\"] != \"CASMI 22\"], x=\"CCS\", bw_adjust=0.35, color=\"black\", multiple=\"stack\", hue=\"Dataset\", palette=[\"gray\"] + [tri_palette[1]], alpha=0.7, fill=True, edgecolor=\"gray\")#, edgecolor=\"lightgray\") #hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "axs[0].spines['top'].set_visible(False)\n", - "axs[0].spines['right'].set_visible(False)\n", - "\n", - "\n", - "sns.scatterplot(ax=axs[1], data=CCS, x=\"CCS\", y=\"CCS_pred\", hue=\"Dataset\", palette=tri_palette, style=\"Dataset\", markers=[(4, 1, 0), \"v\", \"o\", (4,0,45), \"v\", \"D\"],s = 35, linewidth=.0)#, s=50, edgecolor=\"white\")#, linewidth=.0)#, color=\"blue\", edgecolor=\"blue\")#, \n", - "axs[1].set_ylim([df_test_unique[\"CCS\"].min() - 30,df_test_unique[\"CCS\"].max() + 5])\n", - "axs[1].set_xlim([df_test_unique[\"CCS\"].min() - 30,df_test_unique[\"CCS\"].max() + 5])\n", + "df_test_unique = df_test.dropna(subset=[\"CCS\"]).drop_duplicates(\n", + " subset=\"SMILES\", keep=\"first\"\n", + ")\n", + "df_cas22_unique = df_cas22.dropna(subset=[\"CCS\"]).drop_duplicates(\n", + " subset=\"SMILES\", keep=\"first\"\n", + ") # Note that with more lenient filters, CCS values might be annotated for CASMI 22\n", + "\n", + "CCS = pd.concat(\n", + " [\n", + " df_test_unique[[\"CCS\", \"CCS_pred\", \"Dataset\"]],\n", + " df_cas[[\"CCS\", \"CCS_pred\", \"Dataset\"]],\n", + " df_cas22_unique[[\"CCS\", \"CCS_pred\", \"Dataset\"]],\n", + " ],\n", + " ignore_index=True,\n", + ") #\n", + "\n", + "\n", + "# sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "sns.kdeplot(\n", + " ax=axs[0],\n", + " data=CCS[CCS[\"Dataset\"] != \"CASMI 22\"],\n", + " x=\"CCS\",\n", + " bw_adjust=0.35,\n", + " color=\"black\",\n", + " multiple=\"stack\",\n", + " hue=\"Dataset\",\n", + " palette=[\"#696969\"] + [\"white\"],\n", + " linewidth=1.7,\n", + " fill=False,\n", + ") # , edgecolor=\"lightgray\") #hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "sns.kdeplot(\n", + " ax=axs[0],\n", + " data=CCS[CCS[\"Dataset\"] != \"CASMI 22\"],\n", + " x=\"CCS\",\n", + " bw_adjust=0.35,\n", + " color=\"black\",\n", + " multiple=\"stack\",\n", + " hue=\"Dataset\",\n", + " palette=[\"gray\"] + [tri_palette[1]],\n", + " alpha=0.7,\n", + " fill=True,\n", + " edgecolor=\"gray\",\n", + ") # , edgecolor=\"lightgray\") #hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "axs[0].spines[\"top\"].set_visible(False)\n", + "axs[0].spines[\"right\"].set_visible(False)\n", + "\n", + "\n", + "sns.scatterplot(\n", + " ax=axs[1],\n", + " data=CCS,\n", + " x=\"CCS\",\n", + " y=\"CCS_pred\",\n", + " hue=\"Dataset\",\n", + " palette=tri_palette,\n", + " style=\"Dataset\",\n", + " markers=[(4, 1, 0), \"v\", \"o\", (4, 0, 45), \"v\", \"D\"],\n", + " s=35,\n", + " linewidth=0.0,\n", + ") # , s=50, edgecolor=\"white\")#, linewidth=.0)#, color=\"blue\", edgecolor=\"blue\")#,\n", + "axs[1].set_ylim([df_test_unique[\"CCS\"].min() - 30, df_test_unique[\"CCS\"].max() + 5])\n", + "axs[1].set_xlim([df_test_unique[\"CCS\"].min() - 30, df_test_unique[\"CCS\"].max() + 5])\n", 
"axs[1].set_ylabel(\"Predicted CCS\")\n", "axs[1].set_xlabel(\"Observed CCS\")\n", - "line=[df_test_unique[\"CCS\"].min() - 30,df_test_unique[\"CCS\"].max() + 5]\n", + "line = [df_test_unique[\"CCS\"].min() - 30, df_test_unique[\"CCS\"].max() + 5]\n", "sns.lineplot(ax=axs[1], x=line, y=line, color=\"black\")\n", - "sns.lineplot(ax=axs[1], x=line, y=[1.1*x for x in line], color=\"black\", linestyle='--')\n", - "sns.lineplot(ax=axs[1], x=line, y=[0.9*x for x in line], color=\"black\", linestyle='--')\n", + "sns.lineplot(\n", + " ax=axs[1], x=line, y=[1.1 * x for x in line], color=\"black\", linestyle=\"--\"\n", + ")\n", + "sns.lineplot(\n", + " ax=axs[1], x=line, y=[0.9 * x for x in line], color=\"black\", linestyle=\"--\"\n", + ")\n", "\n", - "axs[0].tick_params(axis='both', labelsize=16)\n", - "axs[1].tick_params(axis='both', labelsize=16)\n", - "plt.rc('axes', labelsize=20)\n", - "plt.rc('legend', fontsize=20)\n", + "axs[0].tick_params(axis=\"both\", labelsize=16)\n", + "axs[1].tick_params(axis=\"both\", labelsize=16)\n", + "plt.rc(\"axes\", labelsize=20)\n", + "plt.rc(\"legend\", fontsize=20)\n", "\n", "# fig.savefig(f\"{home}/images/paper/ccs.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/ccs.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -20068,24 +21349,80 @@ } ], "source": [ - "\n", - "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(df_train.dropna(subset=[\"CCS\"])[\"PRECURSORMZ\"], df_train.dropna(subset=[\"CCS\"])[\"CCS\"])\n", + "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(\n", + " df_train.dropna(subset=[\"CCS\"])[\"PRECURSORMZ\"],\n", + " df_train.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + ")\n", "print(\"TEST SPLIT:\\n\")\n", "print(\"Pearson Corr Coef:\")\n", - "print(\"GNN\", np.corrcoef(df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"], df_test_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(), dtype=float)[0,1])\n", - "print(\"LR \", np.corrcoef(df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"], intercept + slope *df_test_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna(), dtype=float)[0,1])\n", + "print(\n", + " \"GNN\",\n", + " np.corrcoef(\n", + " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " dtype=float,\n", + " )[0, 1],\n", + ")\n", + "print(\n", + " \"LR \",\n", + " np.corrcoef(\n", + " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " intercept\n", + " + slope * df_test_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna(),\n", + " dtype=float,\n", + " )[0, 1],\n", + ")\n", "\n", "print(\"R2\")\n", - "print(\"GNN\", r2_score(df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"], df_test_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna()))\n", - "print(\"LR \", r2_score(df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"], intercept + slope *df_test_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna()))\n", + "print(\n", + " \"GNN\",\n", + " r2_score(\n", + " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " ),\n", + ")\n", + "print(\n", + " \"LR \",\n", + " r2_score(\n", + " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " intercept\n", + " + slope * df_test_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna(),\n", + " ),\n", + ")\n", "\n", "print(\"---------------\\n\\nCASMI-16:\\n\")\n", "print(\"Pearson Corr Coef:\")\n", - 
"print(\"GNN\", np.corrcoef(df_cas.dropna(subset=[\"CCS\"])[\"CCS\"], df_cas.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(), dtype=float)[0,1])\n", - "print(\"LR \", np.corrcoef(df_cas.dropna(subset=[\"CCS\"])[\"CCS\"], intercept + slope *df_cas.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna(), dtype=float)[0,1])\n", + "print(\n", + " \"GNN\",\n", + " np.corrcoef(\n", + " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_cas.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " dtype=float,\n", + " )[0, 1],\n", + ")\n", + "print(\n", + " \"LR \",\n", + " np.corrcoef(\n", + " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " intercept + slope * df_cas.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna(),\n", + " dtype=float,\n", + " )[0, 1],\n", + ")\n", "print(\"R2\")\n", - "print(\"GNN\", r2_score(df_cas.dropna(subset=[\"CCS\"])[\"CCS\"], df_cas.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna()))\n", - "print(\"LR \", r2_score(df_cas.dropna(subset=[\"CCS\"])[\"CCS\"], intercept + slope *df_cas.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna()))" + "print(\n", + " \"GNN\",\n", + " r2_score(\n", + " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_cas.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " ),\n", + ")\n", + "print(\n", + " \"LR \",\n", + " r2_score(\n", + " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " intercept + slope * df_cas.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna(),\n", + " ),\n", + ")" ] }, { @@ -20095,9 +21432,9 @@ "outputs": [], "source": [ "# Load coverage into dataframe\n", - "df_test[\"coverage\"] = df_test[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"]) \n", - "df_cas[\"coverage\"] = df_cas[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"]) \n", - "df_cas22[\"coverage\"] = df_cas22[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"]) " + "df_test[\"coverage\"] = df_test[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", + "df_cas[\"coverage\"] = df_cas[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", + "df_cas22[\"coverage\"] = df_cas22[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])" ] }, { @@ -20120,7 +21457,17 @@ "CAT = pd.concat([df_test, df_cas, df_cas22])\n", "fig, ax = plt.subplots(1, 1, figsize=(12, 6))\n", "set_light_theme()\n", - "sns.histplot(ax=ax, data=CAT, x=\"CE\", hue=\"Dataset\", multiple=\"stack\", palette=tri_palette, stat=\"density\", common_norm=False, kde=False)\n", + "sns.histplot(\n", + " ax=ax,\n", + " data=CAT,\n", + " x=\"CE\",\n", + " hue=\"Dataset\",\n", + " multiple=\"stack\",\n", + " palette=tri_palette,\n", + " stat=\"density\",\n", + " common_norm=False,\n", + " kde=False,\n", + ")\n", "plt.xlabel(\"Collision energy\")\n", "plt.show()" ] @@ -20198,12 +21545,24 @@ ], "source": [ "set_light_theme()\n", - "scores = [\"spectral_cosine\", \"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\", \"spectral_refl_cosine\", \"steins_cosine\"]\n", + "scores = [\n", + " \"spectral_cosine\",\n", + " \"spectral_sqrt_cosine\",\n", + " \"spectral_sqrt_cosine_wo_prec\",\n", + " \"spectral_refl_cosine\",\n", + " \"steins_cosine\",\n", + "]\n", "biases = [s.replace(\"cosine\", \"bias\") for s in scores]\n", - "#labels = [\"cosine\", r\"cosine $\\sqrt{int}$ \", r\"$\\sqrt{i}$ cosine w/o precursor\", r\"$\\sqrt{i}$ reflection cosine\", r\"$\\sqrt{i}$ mass-weighted cosine\"]\n", - "\n", - "#labels = [\"cosine\", \"sqrt cosine\", \"sqrt cosine w/o precursor\", \"sqrt reflection cosine\", \"sqrt mass-weighted cosine\"]\n", - "labels = 
[\"Standard\\ncosine\", \"Square root\\nintensities\", \"Square root\\nintensities\\nw/o precursor\", \"Square root\\nintensities\\nreflection score\", \"Square root\\nintensities\\nscaled by m/z\"]\n", + "# labels = [\"cosine\", r\"cosine $\\sqrt{int}$ \", r\"$\\sqrt{i}$ cosine w/o precursor\", r\"$\\sqrt{i}$ reflection cosine\", r\"$\\sqrt{i}$ mass-weighted cosine\"]\n", + "\n", + "# labels = [\"cosine\", \"sqrt cosine\", \"sqrt cosine w/o precursor\", \"sqrt reflection cosine\", \"sqrt mass-weighted cosine\"]\n", + "labels = [\n", + " \"Standard\\ncosine\",\n", + " \"Square root\\nintensities\",\n", + " \"Square root\\nintensities\\nw/o precursor\",\n", + " \"Square root\\nintensities\\nreflection score\",\n", + " \"Square root\\nintensities\\nscaled by m/z\",\n", + "]\n", "S = df_test.melt(id_vars=\"Name\", var_name=\"Score\", value_vars=scores)\n", "B = df_test.melt(id_vars=\"Name\", var_name=\"Score\", value_vars=biases)\n", "\n", @@ -20212,24 +21571,64 @@ "plt.subplots_adjust(top=0.94, bottom=0.12, right=0.97, left=0.08, hspace=0.08)\n", "\n", "\n", - "highlight_2=[sns.color_palette(\"colorblind\")[7], sns.color_palette(\"colorblind\")[9], sns.color_palette(\"colorblind\")[7], sns.color_palette(\"colorblind\")[7], sns.color_palette(\"colorblind\")[7]]\n", - "sns.boxplot(ax=axs[0], data=S, y=\"value\", x=\"Score\", order=scores, hue=\"Score\", palette=highlight_2, showfliers=False)\n", - "sns.boxplot(ax=axs[1], data=B, y=\"value\", x=\"Score\", order=biases, hue=\"Score\", palette=highlight_2, showfliers=False)\n", + "highlight_2 = [\n", + " sns.color_palette(\"colorblind\")[7],\n", + " sns.color_palette(\"colorblind\")[9],\n", + " sns.color_palette(\"colorblind\")[7],\n", + " sns.color_palette(\"colorblind\")[7],\n", + " sns.color_palette(\"colorblind\")[7],\n", + "]\n", + "sns.boxplot(\n", + " ax=axs[0],\n", + " data=S,\n", + " y=\"value\",\n", + " x=\"Score\",\n", + " order=scores,\n", + " hue=\"Score\",\n", + " palette=highlight_2,\n", + " showfliers=False,\n", + ")\n", + "sns.boxplot(\n", + " ax=axs[1],\n", + " data=B,\n", + " y=\"value\",\n", + " x=\"Score\",\n", + " order=biases,\n", + " hue=\"Score\",\n", + " palette=highlight_2,\n", + " showfliers=False,\n", + ")\n", "axs[0].set_xticklabels(\"\")\n", "axs[0].set_xlabel(\"\")\n", "axs[0].set_ylabel(\"Similarity\", fontsize=14)\n", "axs[1].set_xlabel(\"\")\n", "axs[1].set_xticklabels(labels)\n", "axs[1].set_ylabel(\"Bias\", fontsize=14)\n", - "#axs[1].set_ylim([0, 1])\n", - "#axs[1].axhline(y=B.groupby('Score')['value'].median().min(), xmin=0, xmax=10, color=\"red\", linestyle=\"--\")\n", - "\n", - "plt.rc('axes', labelsize=14)\n", - "plt.rc('legend', fontsize=14)\n", - "axs[0].tick_params(axis='both', which='major', labelsize=13)\n", - "axs[1].tick_params(axis='both', which='major', labelsize=13)\n", - "axs[0].text(0.02, 0.025, f\"n={df_test.shape[0]} for all categories\", transform=axs[0].transAxes, fontsize=13, va='bottom', ha='left')\n", - "axs[1].text(0.02, 0.025, f\"n={df_test.shape[0]} for all categories\", transform=axs[1].transAxes, fontsize=13, va='bottom', ha='left')\n", + "# axs[1].set_ylim([0, 1])\n", + "# axs[1].axhline(y=B.groupby('Score')['value'].median().min(), xmin=0, xmax=10, color=\"red\", linestyle=\"--\")\n", + "\n", + "plt.rc(\"axes\", labelsize=14)\n", + "plt.rc(\"legend\", fontsize=14)\n", + "axs[0].tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", + "axs[1].tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", + "axs[0].text(\n", + " 0.02,\n", + " 0.025,\n", + " 
f\"n={df_test.shape[0]} for all categories\",\n", + " transform=axs[0].transAxes,\n", + " fontsize=13,\n", + " va=\"bottom\",\n", + " ha=\"left\",\n", + ")\n", + "axs[1].text(\n", + " 0.02,\n", + " 0.025,\n", + " f\"n={df_test.shape[0]} for all categories\",\n", + " transform=axs[1].transAxes,\n", + " fontsize=13,\n", + " va=\"bottom\",\n", + " ha=\"left\",\n", + ")\n", "\n", "\n", "# fig.savefig(f\"{home}/images/paper/scores_overview.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -20237,8 +21636,8 @@ "# fig.savefig(f\"{home}/images/paper/scores_overview.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "\n", "\n", - "#plt.xticks(rotation=90)\n", - "plt.show()\n" + "# plt.xticks(rotation=90)\n", + "plt.show()" ] }, { @@ -20247,7 +21646,7 @@ "metadata": {}, "outputs": [], "source": [ - "#pd.concat([S, B]).to_excel(f\"{home}/images/paper/SF20.xlsx\")" + "# pd.concat([S, B]).to_excel(f\"{home}/images/paper/SF20.xlsx\")" ] }, { @@ -20316,33 +21715,64 @@ ], "source": [ "set_light_theme()\n", - "fig, ax = plt.subplots(1,1, figsize=(12, 6))\n", - "score_columns = [\"spectral_sqrt_cosine_20\", \"spectral_sqrt_cosine_35\", \"spectral_sqrt_cosine_50\", \"spectral_sqrt_cosine\"]\n", + "fig, ax = plt.subplots(1, 1, figsize=(12, 6))\n", + "score_columns = [\n", + " \"spectral_sqrt_cosine_20\",\n", + " \"spectral_sqrt_cosine_35\",\n", + " \"spectral_sqrt_cosine_50\",\n", + " \"spectral_sqrt_cosine\",\n", + "]\n", "\n", "df_filtered = df_cas[score_columns]\n", "\n", - "df_melted = df_filtered.melt(var_name='Score_Type')\n", - "custom_labels = ['20', '35', '50', '20/35/50']\n", + "df_melted = df_filtered.melt(var_name=\"Score_Type\")\n", + "custom_labels = [\"20\", \"35\", \"50\", \"20/35/50\"]\n", "magma = sns.color_palette(\"magma_r\", 4)\n", - "adjusted_last_color = sns.utils.set_hls_values(magma[1], l=min(1.0, magma[-1][2] * 1.2)) # Increase lightness by 20%\n", + "adjusted_last_color = sns.utils.set_hls_values(\n", + " magma[1], l=min(1.0, magma[-1][2] * 1.2)\n", + ") # Increase lightness by 20%\n", "\n", "# Plot boxplots\n", "plt.figure(figsize=(10, 6))\n", "\n", - "ax = sns.boxplot(ax=ax, data=df_melted, x='Score_Type', y='value', hue='Score_Type', dodge=False, showfliers=False, palette=magma[:3] + [magma[1]], linewidth=2)\n", + "ax = sns.boxplot(\n", + " ax=ax,\n", + " data=df_melted,\n", + " x=\"Score_Type\",\n", + " y=\"value\",\n", + " hue=\"Score_Type\",\n", + " dodge=False,\n", + " showfliers=False,\n", + " palette=magma[:3] + [magma[1]],\n", + " linewidth=2,\n", + ")\n", "bars = ax.patches\n", "bars[-1].set_hatch(\".\")\n", - "#bars[-1].set_hatch_linewidth(2)\n", - "\n", - "ax.axhline(df_filtered[\"spectral_sqrt_cosine\"].median(), color='black', linewidth=2, linestyle='dotted', label='Median of spectral_sqrt_cosine')\n", - "ax.set_xlabel('NCE')\n", - "ax.set_ylabel('Cosine similarity')\n", + "# bars[-1].set_hatch_linewidth(2)\n", + "\n", + "ax.axhline(\n", + " df_filtered[\"spectral_sqrt_cosine\"].median(),\n", + " color=\"black\",\n", + " linewidth=2,\n", + " linestyle=\"dotted\",\n", + " label=\"Median of spectral_sqrt_cosine\",\n", + ")\n", + "ax.set_xlabel(\"NCE\")\n", + "ax.set_ylabel(\"Cosine similarity\")\n", "ax.set_xticklabels(custom_labels)\n", "\n", - "plt.rc('axes', labelsize=14)\n", - "plt.rc('legend', fontsize=14)\n", - "ax.tick_params(axis='both', which='major', labelsize=13) # For major ticks\n", - "ax.text(0.02, 0.02, f\"n={df_filtered.shape[0]} for all data points\", transform=ax.transAxes, fontsize=13, va='bottom', 
ha='left')\n", + "plt.rc(\"axes\", labelsize=14)\n", + "plt.rc(\"legend\", fontsize=14)\n", + "ax.tick_params(axis=\"both\", which=\"major\", labelsize=13) # For major ticks\n", + "ax.text(\n", + " 0.02,\n", + " 0.02,\n", + " f\"n={df_filtered.shape[0]} for all data points\",\n", + " transform=ax.transAxes,\n", + " fontsize=13,\n", + " va=\"bottom\",\n", + " ha=\"left\",\n", + ")\n", "\n", "\n", "# fig.savefig(f\"{home}/images/paper/NCE_casmi16.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -20455,33 +21885,33 @@ } ], "source": [ - "\n", "def filter_peaks(peaks, fraction=0.8, max_peaks=100, min_mz=120):\n", - " \n", + "\n", " # d = {\n", " # \"mz\": peaks[\"mz\"],\n", " # \"intensity\": peaks[\"intensity\"]\n", " # }\n", - " \n", + "\n", " # high_mz_idx = np.where(np.array(d[\"mz\"]) > min_mz)[0]\n", - " # d[\"mz\"] = d[\"mz\"][high_mz_idx.tolist()] \n", + " # d[\"mz\"] = d[\"mz\"][high_mz_idx.tolist()]\n", " # d[\"intensity\"] = d[\"intensity\"][high_mz_idx.tolist()]\n", - " \n", - " order = np.argsort(peaks[\"intensity\"])[::-1] \n", + "\n", + " order = np.argsort(peaks[\"intensity\"])[::-1]\n", " value = fraction * np.sum(peaks[\"intensity\"])\n", - " num_of_relevant_peaks = min(max_peaks - 1, np.argmax(np.cumsum(np.array(peaks[\"intensity\"])[order]) > value))\n", - " indices = order[:num_of_relevant_peaks+1]\n", - " \n", - " \n", + " num_of_relevant_peaks = min(\n", + " max_peaks - 1, np.argmax(np.cumsum(np.array(peaks[\"intensity\"])[order]) > value)\n", + " )\n", + " indices = order[: num_of_relevant_peaks + 1]\n", + "\n", " d = {\n", " \"mz\": np.array(peaks[\"mz\"])[indices].tolist(),\n", - " \"intensity\": np.array(peaks[\"intensity\"])[indices].tolist()\n", + " \"intensity\": np.array(peaks[\"intensity\"])[indices].tolist(),\n", " }\n", "\n", - "\n", " return d\n", " # return d\n", "\n", + "\n", "print(df_cas[\"peaks\"].iloc[0])\n", "p = filter_peaks(df_cas[\"peaks\"].iloc[0], max_peaks=10)\n", "p" @@ -20497,15 +21927,27 @@ "df_cas[\"num_peaks\"] = df_cas[\"peaks\"].apply(lambda x: len(x[\"mz\"]))\n", "df_test[\"num_peaks\"] = df_test[\"peaks\"].apply(lambda x: len(x[\"mz\"]))\n", "df_msnlib_test[\"num_peaks\"] = df_msnlib_test[\"peaks\"].apply(lambda x: len(x[\"mz\"]))\n", - "df_cas22[\"filtered_peaks\"] = df_cas22[\"peaks\"].apply(lambda x: filter_peaks(x, max_peaks=20, fraction=0.8))\n", - "df_cas[\"filtered_peaks\"] = df_cas[\"peaks\"].apply(lambda x: filter_peaks(x, max_peaks=20,fraction=0.8))\n", - "df_test[\"filtered_peaks\"] = df_test[\"peaks\"].apply(lambda x: filter_peaks(x, max_peaks=20, fraction=0.8))\n", - "df_msnlib_test[\"filtered_peaks\"] = df_msnlib_test[\"peaks\"].apply(lambda x: filter_peaks(x, max_peaks=20, fraction=0.8))\n", - "df_cas22[\"num_filtered_peaks\"] = df_cas22[\"filtered_peaks\"].apply(lambda x: len(x[\"mz\"]))\n", + "df_cas22[\"filtered_peaks\"] = df_cas22[\"peaks\"].apply(\n", + " lambda x: filter_peaks(x, max_peaks=20, fraction=0.8)\n", + ")\n", + "df_cas[\"filtered_peaks\"] = df_cas[\"peaks\"].apply(\n", + " lambda x: filter_peaks(x, max_peaks=20, fraction=0.8)\n", + ")\n", + "df_test[\"filtered_peaks\"] = df_test[\"peaks\"].apply(\n", + " lambda x: filter_peaks(x, max_peaks=20, fraction=0.8)\n", + ")\n", + "df_msnlib_test[\"filtered_peaks\"] = df_msnlib_test[\"peaks\"].apply(\n", + " lambda x: filter_peaks(x, max_peaks=20, fraction=0.8)\n", + ")\n", + "df_cas22[\"num_filtered_peaks\"] = df_cas22[\"filtered_peaks\"].apply(\n", + " lambda x: len(x[\"mz\"])\n", + ")\n", "df_cas[\"num_filtered_peaks\"] = 
df_cas[\"filtered_peaks\"].apply(lambda x: len(x[\"mz\"]))\n", "df_test[\"num_filtered_peaks\"] = df_test[\"filtered_peaks\"].apply(lambda x: len(x[\"mz\"]))\n", - "df_msnlib_test[\"num_filtered_peaks\"] = df_msnlib_test[\"filtered_peaks\"].apply(lambda x: len(x[\"mz\"]))\n", - "CAT = pd.concat([df_test, df_msnlib_test, df_cas, df_cas22])\n" + "df_msnlib_test[\"num_filtered_peaks\"] = df_msnlib_test[\"filtered_peaks\"].apply(\n", + " lambda x: len(x[\"mz\"])\n", + ")\n", + "CAT = pd.concat([df_test, df_msnlib_test, df_cas, df_cas22])" ] }, { @@ -20538,50 +21980,85 @@ ], "source": [ "reset_matplotlib()\n", - "fig, axs = plt.subplots(4, 4, figsize=(12, 8), sharex='col', sharey=\"col\")\n", - "plt.subplots_adjust(hspace=0.1, wspace=0.20, right=0.95)#, top=0.94, bottom=0.12, right=0.97, left=0)\n", + "fig, axs = plt.subplots(4, 4, figsize=(12, 8), sharex=\"col\", sharey=\"col\")\n", + "plt.subplots_adjust(\n", + " hspace=0.1, wspace=0.20, right=0.95\n", + ") # , top=0.94, bottom=0.12, right=0.97, left=0)\n", "\n", "dataset_names = [\"Test split\", \"MSnLib\", \"CASMI 16\", \"CASMI 22\"]\n", "\n", "\n", "# Loop through each row and set row labels\n", "for i, dataset_name in enumerate(dataset_names):\n", - " axs[i, 0].set_ylabel(dataset_name, rotation=90, labelpad=10, ha='center', va='center', fontsize=12, fontweight='bold')\n", + " axs[i, 0].set_ylabel(\n", + " dataset_name,\n", + " rotation=90,\n", + " labelpad=10,\n", + " ha=\"center\",\n", + " va=\"center\",\n", + " fontsize=12,\n", + " fontweight=\"bold\",\n", + " )\n", " # Column 1\n", - " sns.histplot(ax=axs[i, 0], data=CAT[CAT[\"Dataset\"] == dataset_name], x=\"CE\", binwidth=5, stat=\"percent\")\n", + " sns.histplot(\n", + " ax=axs[i, 0],\n", + " data=CAT[CAT[\"Dataset\"] == dataset_name],\n", + " x=\"CE\",\n", + " binwidth=5,\n", + " stat=\"percent\",\n", + " )\n", " axs[i, 0].set_xlim(0, 100)\n", " axs[i, 0].set_xticks([0, 25, 50, 75, 100])\n", " axs[i, 0].set_xlabel(\"Collision energy (eV)\", fontsize=12)\n", - " axs[i, 0].tick_params(axis='both', which='major', labelsize=11)\n", + " axs[i, 0].tick_params(axis=\"both\", which=\"major\", labelsize=11)\n", "\n", " # Column 2\n", - " sns.histplot(ax=axs[i, 1], data=CAT[CAT[\"Dataset\"] == dataset_name], x=\"num_peaks\", binwidth=10, stat=\"percent\")\n", - " axs[i, 1].set_xticks(list(range(0,260,50)))\n", + " sns.histplot(\n", + " ax=axs[i, 1],\n", + " data=CAT[CAT[\"Dataset\"] == dataset_name],\n", + " x=\"num_peaks\",\n", + " binwidth=10,\n", + " stat=\"percent\",\n", + " )\n", + " axs[i, 1].set_xticks(list(range(0, 260, 50)))\n", " axs[i, 1].set_xlim(0, 250)\n", " axs[i, 1].set_xlabel(\"Number of peaks\", fontsize=12)\n", " axs[i, 1].set_ylabel(\"\")\n", - " axs[i, 1].tick_params(axis='both', which='major', labelsize=11)\n", - " \n", + " axs[i, 1].tick_params(axis=\"both\", which=\"major\", labelsize=11)\n", + "\n", " # Column 3\n", - " mz = [item for sublist in CAT[CAT[\"Dataset\"] == dataset_name][\"peaks\"].apply(lambda x: x[\"mz\"]) for item in sublist] \n", + " mz = [\n", + " item\n", + " for sublist in CAT[CAT[\"Dataset\"] == dataset_name][\"peaks\"].apply(\n", + " lambda x: x[\"mz\"]\n", + " )\n", + " for item in sublist\n", + " ]\n", " sns.histplot(mz, ax=axs[i, 2], binwidth=10, stat=\"percent\")\n", - " axs[i, 2].set_xticks(list(range(0,550,100)))\n", + " axs[i, 2].set_xticks(list(range(0, 550, 100)))\n", " axs[i, 2].set_xlim(0, 500)\n", " axs[i, 2].set_xlabel(\"Peak m/z\", fontsize=12)\n", " axs[i, 2].set_ylabel(\"\")\n", - " axs[i, 2].axvline(x=125, color='black', 
linestyle='--', linewidth=1)\n", - " axs[i, 2].tick_params(axis='both', which='major', labelsize=11)\n", - " \n", + " axs[i, 2].axvline(x=125, color=\"black\", linestyle=\"--\", linewidth=1)\n", + " axs[i, 2].tick_params(axis=\"both\", which=\"major\", labelsize=11)\n", + "\n", " # Column 4\n", - " sns.histplot(ax=axs[i, 3], data=CAT[CAT[\"Dataset\"] == dataset_name], x=\"num_filtered_peaks\", hue=\"Precursor_type\", binwidth=1, stat=\"percent\", multiple=\"dodge\", hue_order=[\"[M+H]+\", \"[M-H]-\"])\n", + " sns.histplot(\n", + " ax=axs[i, 3],\n", + " data=CAT[CAT[\"Dataset\"] == dataset_name],\n", + " x=\"num_filtered_peaks\",\n", + " hue=\"Precursor_type\",\n", + " binwidth=1,\n", + " stat=\"percent\",\n", + " multiple=\"dodge\",\n", + " hue_order=[\"[M+H]+\", \"[M-H]-\"],\n", + " )\n", " axs[i, 3].set_xlim(1, 21)\n", " if i > 0:\n", " axs[i, 3].legend().set_visible(False)\n", " axs[i, 3].set_xlabel(\"Number of peaks\\nexplaining 80% of intensity\", fontsize=12)\n", " axs[i, 3].set_ylabel(\"\")\n", - " axs[i, 3].tick_params(axis='both', which='major', labelsize=11)\n", - "\n", - "\n", + " axs[i, 3].tick_params(axis=\"both\", which=\"major\", labelsize=11)\n", "\n", "\n", "# fig.savefig(f\"{home}/images/paper/data_hists.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -20596,8 +22073,8 @@ "metadata": {}, "outputs": [], "source": [ - "#CAT[\"peak m/z\"] = CAT[\"peaks\"].apply(lambda x: x[\"mz\"])\n", - "#CAT[[\"Dataset\", \"CE\", \"num_peaks\", \"peak m/z\", \"num_filtered_peaks\"]].to_excel(f\"{home}/images/paper/SF15.xlsx\")" + "# CAT[\"peak m/z\"] = CAT[\"peaks\"].apply(lambda x: x[\"mz\"])\n", + "# CAT[[\"Dataset\", \"CE\", \"num_peaks\", \"peak m/z\", \"num_filtered_peaks\"]].to_excel(f\"{home}/images/paper/SF15.xlsx\")" ] }, { @@ -20607,7 +22084,11 @@ "outputs": [], "source": [ "from fiora.MS.spectral_scores import spectral_cosine\n", - "df_cas22[\"filtered_cosine\"] = df_cas22.apply(lambda x: spectral_cosine(x[\"sim_peaks\"], x[\"filtered_peaks\"], transform=np.sqrt) , axis=1)" + "\n", + "df_cas22[\"filtered_cosine\"] = df_cas22.apply(\n", + " lambda x: spectral_cosine(x[\"sim_peaks\"], x[\"filtered_peaks\"], transform=np.sqrt),\n", + " axis=1,\n", + ")" ] }, { @@ -20667,7 +22148,7 @@ } ], "source": [ - "np.sum(df_cas[\"spectral_sqrt_cosine\"]>=0.70) / df_cas.shape[0]" + "np.sum(df_cas[\"spectral_sqrt_cosine\"] >= 0.70) / df_cas.shape[0]" ] }, { @@ -20697,7 +22178,7 @@ "outputs": [], "source": [ "df_test[\"CE_10\"] = np.ceil(df_test[\"CE\"] / 10) * 10\n", - "set_light_theme()\n" + "set_light_theme()" ] }, { @@ -20740,49 +22221,94 @@ } ], "source": [ - "fig, axs = plt.subplots(2, 2, figsize=(16,12), gridspec_kw={'width_ratios': [1, 1]}, sharex=True, sharey=True)\n", - "plt.subplots_adjust(hspace=0.12, wspace=0.05)#top=0.94, bottom=0.12, right=0.97, left=0.08)\n", + "fig, axs = plt.subplots(\n", + " 2,\n", + " 2,\n", + " figsize=(16, 12),\n", + " gridspec_kw={\"width_ratios\": [1, 1]},\n", + " sharex=True,\n", + " sharey=True,\n", + ")\n", + "plt.subplots_adjust(\n", + " hspace=0.12, wspace=0.05\n", + ") # top=0.94, bottom=0.12, right=0.97, left=0.08)\n", "\n", "magma10 = sns.color_palette(\"magma_r\", 18)[:10]\n", - "sns.boxplot(ax=axs[0,0], data=df_test, y=\"spectral_sqrt_cosine\", x=\"CE_10\", showfliers=False, palette=magma10, linewidth=1.5) # color=\"white\", linewidth=2, linecolor=\"black\")\n", - "axs[0,0].set_title(\"Fiora\", fontweight=\"bold\")\n", - "axs[0,0].set_ylabel(\"Cosine similarity\")\n", - "\n", - "sns.boxplot(ax=axs[0,1], data=df_test, 
y=\"ice_sqrt_cosine\", x=\"CE_10\", showfliers=False, palette=magma10, linewidth=1.5)\n", - "axs[0,1].set_title(\"ICEBERG\", fontweight=\"bold\")\n", - "axs[0,1].set_xlabel(\"Collision energy\")\n", - "axs[0, 1].xaxis.set_label_position('bottom')\n", - "axs[0,1].set_xticklabels(list(range(10, 110, 10)))\n", - "axs[0, 1].tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=True) # Ensure tick labels visibility\n", + "sns.boxplot(\n", + " ax=axs[0, 0],\n", + " data=df_test,\n", + " y=\"spectral_sqrt_cosine\",\n", + " x=\"CE_10\",\n", + " showfliers=False,\n", + " palette=magma10,\n", + " linewidth=1.5,\n", + ") # color=\"white\", linewidth=2, linecolor=\"black\")\n", + "axs[0, 0].set_title(\"Fiora\", fontweight=\"bold\")\n", + "axs[0, 0].set_ylabel(\"Cosine similarity\")\n", + "\n", + "sns.boxplot(\n", + " ax=axs[0, 1],\n", + " data=df_test,\n", + " y=\"ice_sqrt_cosine\",\n", + " x=\"CE_10\",\n", + " showfliers=False,\n", + " palette=magma10,\n", + " linewidth=1.5,\n", + ")\n", + "axs[0, 1].set_title(\"ICEBERG\", fontweight=\"bold\")\n", + "axs[0, 1].set_xlabel(\"Collision energy\")\n", + "axs[0, 1].xaxis.set_label_position(\"bottom\")\n", + "axs[0, 1].set_xticklabels(list(range(10, 110, 10)))\n", + "axs[0, 1].tick_params(\n", + " axis=\"x\", which=\"both\", bottom=True, top=False, labelbottom=True\n", + ") # Ensure tick labels visibility\n", "axs[0, 1].xaxis.label.set_visible(True)\n", "\n", - "sns.boxplot(ax=axs[1,0], data=df_test, y=\"cfm_sqrt_cosine\", x=\"CE_10\", showfliers=False, palette=magma10, linewidth=1.5)\n", - "axs[1,0].set_title(\"CFM-ID\", fontweight=\"bold\")\n", - "axs[1,0].set_xlabel(\"Collision energy\")\n", - "axs[1,0].set_ylabel(\"Cosine similarity\")\n", + "sns.boxplot(\n", + " ax=axs[1, 0],\n", + " data=df_test,\n", + " y=\"cfm_sqrt_cosine\",\n", + " x=\"CE_10\",\n", + " showfliers=False,\n", + " palette=magma10,\n", + " linewidth=1.5,\n", + ")\n", + "axs[1, 0].set_title(\"CFM-ID\", fontweight=\"bold\")\n", + "axs[1, 0].set_xlabel(\"Collision energy\")\n", + "axs[1, 0].set_ylabel(\"Cosine similarity\")\n", "axs[1, 1].remove()\n", "\n", "for i in range(2):\n", " for j in range(2):\n", " if axs[i, j].has_data(): # Check if the subplot has data\n", " for tick, label in zip(axs[i, j].get_xticks(), axs[i, j].get_xticklabels()):\n", - " count = len(df_test[df_test[\"CE_10\"] == int(label.get_text())]) # Count data points per category\n", + " count = len(\n", + " df_test[df_test[\"CE_10\"] == int(label.get_text())]\n", + " ) # Count data points per category\n", " axs[i, j].text(\n", - " tick, axs[i, j].get_ylim()[0] + 0.04, f\"n={count}\", \n", - " ha='center', va='top', fontsize=10.25\n", + " tick,\n", + " axs[i, j].get_ylim()[0] + 0.04,\n", + " f\"n={count}\",\n", + " ha=\"center\",\n", + " va=\"top\",\n", + " fontsize=10.25,\n", " )\n", " if i == 0 and j == 1:\n", " axs[0, 0].text(\n", - " tick, axs[0, 0].get_ylim()[0] + 0.04, f\"n={count}\", \n", - " ha='center', va='top', fontsize=10.25\n", + " tick,\n", + " axs[0, 0].get_ylim()[0] + 0.04,\n", + " f\"n={count}\",\n", + " ha=\"center\",\n", + " va=\"top\",\n", + " fontsize=10.25,\n", " )\n", "\n", - "plt.rc('axes', labelsize=14)\n", - "plt.rc('legend', fontsize=14)\n", + "plt.rc(\"axes\", labelsize=14)\n", + "plt.rc(\"legend\", fontsize=14)\n", "plt.ylim([-0.05, 1.05])\n", - "axs[0,0].tick_params(axis='both', which='major', labelsize=13) # For major ticks\n", - "axs[0,1].tick_params(axis='both', which='major', labelsize=13) # For major ticks\n", - "axs[1,0].tick_params(axis='both', which='major', 
labelsize=13) # For major ticks\n", + "axs[0, 0].tick_params(axis=\"both\", which=\"major\", labelsize=13) # For major ticks\n", + "axs[0, 1].tick_params(axis=\"both\", which=\"major\", labelsize=13) # For major ticks\n", + "axs[1, 0].tick_params(axis=\"both\", which=\"major\", labelsize=13) # For major ticks\n", "\n", "# fig.savefig(f\"{home}/images/paper/cosine_ce_wo_prec.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/cosine_ce.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -20832,7 +22358,14 @@ "source": [ "fig, ax = plt.subplots(1, 1, figsize=(12, 6), sharex=True, sharey=True)\n", "\n", - "sns.boxplot(data=df_test, y=\"coverage\", x=\"CE_10\", showfliers=False, palette=magma10, linewidth=1.5) # color=\"white\", linewidth=2, linecolor=\"black\")\n", + "sns.boxplot(\n", + " data=df_test,\n", + " y=\"coverage\",\n", + " x=\"CE_10\",\n", + " showfliers=False,\n", + " palette=magma10,\n", + " linewidth=1.5,\n", + ") # color=\"white\", linewidth=2, linecolor=\"black\")\n", "ax.set_xticklabels(list(range(10, 110, 10)))\n", "ax.set_xlabel(\"Collision energy\")\n", "ax.set_ylabel(\"Peak intensity coverage\")\n", @@ -20840,19 +22373,20 @@ "# Add 'n=XXX' annotations\n", "for tick, label in zip(ax.get_xticks(), ax.get_xticklabels()):\n", " count = len(df_test[df_test[\"CE_10\"] == int(label.get_text())])\n", - " ax.text(tick, ax.get_ylim()[0] + 0.01, f\"n={count}\", \n", - " ha='center', va='top', fontsize=13)\n", + " ax.text(\n", + " tick, ax.get_ylim()[0] + 0.01, f\"n={count}\", ha=\"center\", va=\"top\", fontsize=13\n", + " )\n", "\n", "plt.ylim([0.45, 1.03])\n", - "plt.rc('axes', labelsize=14)\n", - "plt.rc('legend', fontsize=14)\n", - "ax.tick_params(axis='both', which='major', labelsize=13) # For major ticks\n", - "#plt.subplots_adjust(bottom=0.2)\n", + "plt.rc(\"axes\", labelsize=14)\n", + "plt.rc(\"legend\", fontsize=14)\n", + "ax.tick_params(axis=\"both\", which=\"major\", labelsize=13) # For major ticks\n", + "# plt.subplots_adjust(bottom=0.2)\n", "\n", "# fig.savefig(f\"{home}/images/paper/coverage_ce.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/coverage_ce.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/coverage_ce.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", - "#df_test[[\"coverage\", \"CE\", \"CE_10\"]].to_excel(f\"{home}/images/paper/SF10.xlsx\")\n", + "# df_test[[\"coverage\", \"CE\", \"CE_10\"]].to_excel(f\"{home}/images/paper/SF10.xlsx\")\n", "\n", "plt.show()" ] @@ -20875,43 +22409,74 @@ ], "source": [ "magma = sns.color_palette(\"magma_r\", 4)\n", - "fig, axs = plt.subplots(1, 3, figsize=(18,6), gridspec_kw={'width_ratios': [1, 1, 1]}, sharey=True)\n", + "fig, axs = plt.subplots(\n", + " 1, 3, figsize=(18, 6), gridspec_kw={\"width_ratios\": [1, 1, 1]}, sharey=True\n", + ")\n", "plt.subplots_adjust(wspace=0.05)\n", "\n", - "sns.boxplot(ax=axs[0], data=df_cas22, y=\"spectral_sqrt_cosine\", x=\"NCE\", hue=\"NCE\", palette=magma[:3], showfliers=False)\n", - "#sns.boxplot(ax=axs[1], data=df_cas22, y=\"filtered_cosine\", x=\"NCE\", hue=\"NCE\", palette=magma[:3])\n", - "\n", - "sns.boxplot(ax=axs[1], data=df_cas22, y=\"cfm_sqrt_cosine\", x=\"NCE\", hue=\"NCE\", palette=magma[:3], showfliers=False)\n", - "sns.boxplot(ax=axs[2], data=df_cas22, y=\"ice_sqrt_cosine\", x=\"NCE\", hue=\"NCE\", palette=magma[:3], showfliers=False)\n", + 
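Note: the surrounding CASMI 22 cell compares cosine similarity at normalized collision energies (NCE 20/35/50), whereas the test split above uses absolute collision energies in eV. For singly charged precursors, a commonly cited instrument-vendor approximation relates the two as CE[eV] ≈ NCE × (precursor m/z) / 500; whether fiora.MOL.collision_energy implements exactly this form is not visible in this diff, so the sketch below is an assumption for orientation only.

# Commonly cited NCE -> absolute collision energy approximation (charge 1).
# Assumed form; not necessarily fiora.MOL.collision_energy's implementation.
def nce_to_ev(nce: float, precursor_mz: float) -> float:
    return nce * precursor_mz / 500.0

for nce in (20, 35, 50):  # the three CASMI 22 settings compared here
    print(f"NCE {nce} at m/z 300 -> {nce_to_ev(nce, 300.0):.1f} eV")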
"sns.boxplot(\n", + " ax=axs[0],\n", + " data=df_cas22,\n", + " y=\"spectral_sqrt_cosine\",\n", + " x=\"NCE\",\n", + " hue=\"NCE\",\n", + " palette=magma[:3],\n", + " showfliers=False,\n", + ")\n", + "# sns.boxplot(ax=axs[1], data=df_cas22, y=\"filtered_cosine\", x=\"NCE\", hue=\"NCE\", palette=magma[:3])\n", + "\n", + "sns.boxplot(\n", + " ax=axs[1],\n", + " data=df_cas22,\n", + " y=\"cfm_sqrt_cosine\",\n", + " x=\"NCE\",\n", + " hue=\"NCE\",\n", + " palette=magma[:3],\n", + " showfliers=False,\n", + ")\n", + "sns.boxplot(\n", + " ax=axs[2],\n", + " data=df_cas22,\n", + " y=\"ice_sqrt_cosine\",\n", + " x=\"NCE\",\n", + " hue=\"NCE\",\n", + " palette=magma[:3],\n", + " showfliers=False,\n", + ")\n", "axs[0].set_title(\"Fiora\", fontweight=\"bold\", fontsize=16)\n", - "axs[1].set_title(\"CFM-ID\", fontweight=\"bold\",fontsize=16)\n", + "axs[1].set_title(\"CFM-ID\", fontweight=\"bold\", fontsize=16)\n", "axs[2].set_title(\"ICEBERG\", fontweight=\"bold\", fontsize=16)\n", "axs[0].set_ylabel(\"Cosine similarity\")\n", "for ax in axs:\n", " ax.get_legend().remove()\n", "\n", - "plt.rc('axes', labelsize=14)\n", - "plt.rc('legend', fontsize=14)\n", + "plt.rc(\"axes\", labelsize=14)\n", + "plt.rc(\"legend\", fontsize=14)\n", "plt.ylim(-0.08, 1.04)\n", - "axs[0].tick_params(axis='both', which='major', labelsize=13)\n", - "axs[1].tick_params(axis='both', which='major', labelsize=13)\n", - "axs[2].tick_params(axis='both', which='major', labelsize=13)\n", + "axs[0].tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", + "axs[1].tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", + "axs[2].tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", "\n", "for ax in axs:\n", " for tick, label in zip(ax.get_xticks(), ax.get_xticklabels()):\n", - " \n", " category = float(label.get_text()) # Extract the label text\n", - " count = len(df_cas22[df_cas22[\"NCE\"] == category]) # Count samples for the category\n", + " count = len(\n", + " df_cas22[df_cas22[\"NCE\"] == category]\n", + " ) # Count samples for the category\n", " ax.text(\n", - " tick, ax.get_ylim()[0] + 0.05, f\"n={count}\",\n", - " ha='center', va='top', fontsize=13\n", + " tick,\n", + " ax.get_ylim()[0] + 0.05,\n", + " f\"n={count}\",\n", + " ha=\"center\",\n", + " va=\"top\",\n", + " fontsize=13,\n", " )\n", "\n", "# fig.savefig(f\"{home}/images/paper/casmi22_nce_wo_prec.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/casmi22_nce.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/casmi22_nce.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -20920,7 +22485,7 @@ "metadata": {}, "outputs": [], "source": [ - "#df_cas22[[\"ChallengeName\", \"NCE\", \"spectral_sqrt_cosine\", \"ice_sqrt_cosine\", \"cfm_sqrt_cosine\"]].to_excel(f\"{home}/images/paper/SF13.xlsx\")" + "# df_cas22[[\"ChallengeName\", \"NCE\", \"spectral_sqrt_cosine\", \"ice_sqrt_cosine\", \"cfm_sqrt_cosine\"]].to_excel(f\"{home}/images/paper/SF13.xlsx\")" ] }, { @@ -20950,14 +22515,27 @@ "set_light_theme()\n", "\n", "\n", - "fig, axs = plt.subplots(2, 1, figsize=(12, 14), sharex=True, gridspec_kw={'height_ratios': [1, 5]})\n", - "plt.subplots_adjust(hspace=0.025)#right=0.975, left=0.11)\n", - "#sns.histplot(ax=axs[0], data=df_cas22, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) 
#hue=\"Precursor_type\", \n", - "plt.rc('legend', loc=\"upper center\")\n", - "sns.kdeplot(ax=axs[0], data=CAT[CAT[\"Dataset\"].isin([\"CASMI 16\", \"CASMI 22\"])], x=\"coverage\", bw_adjust=0.2, color=\"black\", fill=True, multiple=\"layer\", hue=\"Dataset\", common_norm=False, palette=tri_palette[1:]) #hue=\"Precursor_type\", \n", - "axs[0].spines['top'].set_visible(False)\n", - "axs[0].spines['right'].set_visible(False)\n", - "#axs[0].legend(loc='upper center')\n", + "fig, axs = plt.subplots(\n", + " 2, 1, figsize=(12, 14), sharex=True, gridspec_kw={\"height_ratios\": [1, 5]}\n", + ")\n", + "plt.subplots_adjust(hspace=0.025) # right=0.975, left=0.11)\n", + "# sns.histplot(ax=axs[0], data=df_cas22, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "plt.rc(\"legend\", loc=\"upper center\")\n", + "sns.kdeplot(\n", + " ax=axs[0],\n", + " data=CAT[CAT[\"Dataset\"].isin([\"CASMI 16\", \"CASMI 22\"])],\n", + " x=\"coverage\",\n", + " bw_adjust=0.2,\n", + " color=\"black\",\n", + " fill=True,\n", + " multiple=\"layer\",\n", + " hue=\"Dataset\",\n", + " common_norm=False,\n", + " palette=tri_palette[1:],\n", + ") # hue=\"Precursor_type\",\n", + "axs[0].spines[\"top\"].set_visible(False)\n", + "axs[0].spines[\"right\"].set_visible(False)\n", + "# axs[0].legend(loc='upper center')\n", "plt.xlim([-0.05, 1.05])\n", "axs[1].set_ylim([-0.05, 1.05])\n", "# Generating x values\n", @@ -20965,11 +22543,21 @@ "y = np.sqrt(x)\n", "sns.lineplot(x=x, y=y, color=\"black\", linestyle=\"dotted\")\n", "\n", - "#axs[0].set_title(\"Impact of coverage on cosine scores\")\n", - "sns.scatterplot(ax=axs[1], data=CAT[CAT[\"Dataset\"].isin([\"CASMI 16\", \"CASMI 22\"])], x=\"coverage\", y=\"spectral_sqrt_cosine\", hue=\"Dataset\", style=\"Dataset\", markers=[(4,1,0), \".\"], s=100, palette=tri_palette[1:]) #, hue_norm=(0, 1), palette=bluepink_grad)\n", + "# axs[0].set_title(\"Impact of coverage on cosine scores\")\n", + "sns.scatterplot(\n", + " ax=axs[1],\n", + " data=CAT[CAT[\"Dataset\"].isin([\"CASMI 16\", \"CASMI 22\"])],\n", + " x=\"coverage\",\n", + " y=\"spectral_sqrt_cosine\",\n", + " hue=\"Dataset\",\n", + " style=\"Dataset\",\n", + " markers=[(4, 1, 0), \".\"],\n", + " s=100,\n", + " palette=tri_palette[1:],\n", + ") # , hue_norm=(0, 1), palette=bluepink_grad)\n", "axs[1].set_ylabel(\"Cosine similarity\")\n", "axs[1].set_xlabel(\"Peak intensity coverage\")\n", - "axs[1].legend(title=\"Dataset\", loc='upper left')\n", + "axs[1].legend(title=\"Dataset\", loc=\"upper left\")\n", "# fig.savefig(f\"{home}/images/paper/coverage_top_only.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/coverage_top_only.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/coverage_top_only.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -21003,52 +22591,90 @@ ], "source": [ "set_light_theme()\n", - "fig, axs = plt.subplots(2, 2, figsize=(14, 14), sharex=False, sharey=False, gridspec_kw={'height_ratios': [1, 5], 'width_ratios': [5, 1]})\n", - "plt.subplots_adjust(hspace=0.025, wspace=0.025)#hspace=0.025)#right=0.975, left=0.11)\n", - "#sns.histplot(ax=axs[0], data=df_cas22, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "plt.rc('legend', loc=\"upper 
center\")\n", - "axs[0,0].tick_params(axis='both', labelsize=13)\n", - "axs[1,0].tick_params(axis='both', labelsize=13)\n", - "axs[1,1].tick_params(axis='both', labelsize=13)\n", - "plt.rc('axes', labelsize=20)\n", - "plt.rc('legend', fontsize=14)\n", - "\n", - "\n", - "sns.kdeplot(ax=axs[0,0], data=CAT[CAT[\"Dataset\"].isin([\"CASMI 16\", \"CASMI 22\"])], x=\"coverage\", bw_adjust=0.2, color=\"black\", fill=True, multiple=\"layer\", hue=\"Dataset\", common_norm=False, palette=tri_palette[1:]) #hue=\"Precursor_type\", \n", - "axs[0,0].spines['top'].set_visible(False)\n", - "axs[0,0].spines['right'].set_visible(False)\n", - "axs[0,0].set_xlim([-0.05, 1.05])\n", - "axs[0,0].set_xlabel(\"\")\n", - "axs[0,0].set_xticklabels(\"\")\n", - "#axs[0].legend(loc='upper center')\n", - "axs[1,0].set_xlim([-0.05, 1.05])\n", - "axs[1,0].set_ylim([-0.05, 1.05])\n", - "axs[1, 0].set_aspect('equal')\n", + "fig, axs = plt.subplots(\n", + " 2,\n", + " 2,\n", + " figsize=(14, 14),\n", + " sharex=False,\n", + " sharey=False,\n", + " gridspec_kw={\"height_ratios\": [1, 5], \"width_ratios\": [5, 1]},\n", + ")\n", + "plt.subplots_adjust(hspace=0.025, wspace=0.025) # hspace=0.025)#right=0.975, left=0.11)\n", + "# sns.histplot(ax=axs[0], data=df_cas22, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "plt.rc(\"legend\", loc=\"upper center\")\n", + "axs[0, 0].tick_params(axis=\"both\", labelsize=13)\n", + "axs[1, 0].tick_params(axis=\"both\", labelsize=13)\n", + "axs[1, 1].tick_params(axis=\"both\", labelsize=13)\n", + "plt.rc(\"axes\", labelsize=20)\n", + "plt.rc(\"legend\", fontsize=14)\n", + "\n", + "\n", + "sns.kdeplot(\n", + " ax=axs[0, 0],\n", + " data=CAT[CAT[\"Dataset\"].isin([\"CASMI 16\", \"CASMI 22\"])],\n", + " x=\"coverage\",\n", + " bw_adjust=0.2,\n", + " color=\"black\",\n", + " fill=True,\n", + " multiple=\"layer\",\n", + " hue=\"Dataset\",\n", + " common_norm=False,\n", + " palette=tri_palette[1:],\n", + ") # hue=\"Precursor_type\",\n", + "axs[0, 0].spines[\"top\"].set_visible(False)\n", + "axs[0, 0].spines[\"right\"].set_visible(False)\n", + "axs[0, 0].set_xlim([-0.05, 1.05])\n", + "axs[0, 0].set_xlabel(\"\")\n", + "axs[0, 0].set_xticklabels(\"\")\n", + "# axs[0].legend(loc='upper center')\n", + "axs[1, 0].set_xlim([-0.05, 1.05])\n", + "axs[1, 0].set_ylim([-0.05, 1.05])\n", + "axs[1, 0].set_aspect(\"equal\")\n", "# Generating x values\n", "x = np.linspace(0, 1, 200)\n", "y = np.sqrt(x)\n", - "sns.lineplot(x=x, y=y, color=\"black\", linestyle=\"dotted\", ax=axs[1,0])\n", - "\n", - "\n", - "#axs[0].set_title(\"Impact of coverage on cosine scores\")\n", - "sns.scatterplot(ax=axs[1,0], data=CAT[CAT[\"Dataset\"].isin([\"CASMI 16\", \"CASMI 22\"])], x=\"coverage\", y=\"spectral_sqrt_cosine\", hue=\"Dataset\", style=\"Dataset\", markers=[(4,1,0), \".\"], s=100, palette=tri_palette[1:]) #, hue_norm=(0, 1), palette=bluepink_grad)\n", - "axs[1,0].set_ylabel(\"Cosine similarity\")\n", - "axs[1,0].set_xlabel(\"Peak intensity coverage\")\n", - "axs[1,0].legend(title=\"Dataset\", loc='upper left')\n", - "\n", + "sns.lineplot(x=x, y=y, color=\"black\", linestyle=\"dotted\", ax=axs[1, 0])\n", + "\n", + "\n", + "# axs[0].set_title(\"Impact of coverage on cosine scores\")\n", + "sns.scatterplot(\n", + " ax=axs[1, 0],\n", + " data=CAT[CAT[\"Dataset\"].isin([\"CASMI 16\", \"CASMI 22\"])],\n", + " x=\"coverage\",\n", + " y=\"spectral_sqrt_cosine\",\n", + " hue=\"Dataset\",\n", + " 
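# style/markers give each CASMI set its own symbol (4-point star vs. dot)\n",
+    "    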
style=\"Dataset\",\n", + " markers=[(4, 1, 0), \".\"],\n", + " s=100,\n", + " palette=tri_palette[1:],\n", + ") # , hue_norm=(0, 1), palette=bluepink_grad)\n", + "axs[1, 0].set_ylabel(\"Cosine similarity\")\n", + "axs[1, 0].set_xlabel(\"Peak intensity coverage\")\n", + "axs[1, 0].legend(title=\"Dataset\", loc=\"upper left\")\n", "\n", "\n", "# Box plots\n", "\n", - "sns.kdeplot(ax=axs[1,1], data=CAT[CAT[\"Dataset\"].isin([\"CASMI 16\", \"CASMI 22\"])], y=\"spectral_sqrt_cosine\", bw_adjust=0.2, color=\"black\", fill=True, multiple=\"layer\", hue=\"Dataset\", common_norm=False, palette=tri_palette[1:]) #hue=\"Precursor_type\", \n", - "#sns.boxplot(ax=axs[1,1], data=CAT[CAT[\"Dataset\"] != \"Test split\"], y=\"spectral_sqrt_cosine\", hue=\"Dataset\", palette=tri_palette[1:])\n", - "\n", - "axs[1,1].set_ylim(axs[1,0].get_ylim())\n", - "axs[1,1].set_ylabel(\"\")\n", - "axs[1,1].set_yticklabels(\"\")\n", - "axs[1,1].legend().remove()\n", - "axs[1,1].spines['top'].set_visible(False)\n", - "axs[1,1].spines['right'].set_visible(False)\n", + "sns.kdeplot(\n", + " ax=axs[1, 1],\n", + " data=CAT[CAT[\"Dataset\"].isin([\"CASMI 16\", \"CASMI 22\"])],\n", + " y=\"spectral_sqrt_cosine\",\n", + " bw_adjust=0.2,\n", + " color=\"black\",\n", + " fill=True,\n", + " multiple=\"layer\",\n", + " hue=\"Dataset\",\n", + " common_norm=False,\n", + " palette=tri_palette[1:],\n", + ") # hue=\"Precursor_type\",\n", + "# sns.boxplot(ax=axs[1,1], data=CAT[CAT[\"Dataset\"] != \"Test split\"], y=\"spectral_sqrt_cosine\", hue=\"Dataset\", palette=tri_palette[1:])\n", + "\n", + "axs[1, 1].set_ylim(axs[1, 0].get_ylim())\n", + "axs[1, 1].set_ylabel(\"\")\n", + "axs[1, 1].set_yticklabels(\"\")\n", + "axs[1, 1].legend().remove()\n", + "axs[1, 1].spines[\"top\"].set_visible(False)\n", + "axs[1, 1].spines[\"right\"].set_visible(False)\n", "fig.delaxes(axs[0, 1])\n", "\n", "for ax in axs.flat:\n", @@ -21057,16 +22683,16 @@ " ax.set_xlabel(ax.get_xlabel(), fontsize=14)\n", " ax.set_ylabel(ax.get_ylabel(), fontsize=14)\n", "\n", - "axs[0,0].tick_params(axis='both', labelsize=13)\n", - "axs[1,0].tick_params(axis='both', labelsize=13)\n", - "axs[1,1].tick_params(axis='both', labelsize=13)\n", - "plt.rc('axes', labelsize=20)\n", - "plt.rc('legend', fontsize=14)\n", + "axs[0, 0].tick_params(axis=\"both\", labelsize=13)\n", + "axs[1, 0].tick_params(axis=\"both\", labelsize=13)\n", + "axs[1, 1].tick_params(axis=\"both\", labelsize=13)\n", + "plt.rc(\"axes\", labelsize=20)\n", + "plt.rc(\"legend\", fontsize=14)\n", "\n", "\n", "# fig.savefig(f\"{home}/images/paper/coverage.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/coverage.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", - "#fig.savefig(f\"{home}/images/paper/coverage.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", + "# fig.savefig(f\"{home}/images/paper/coverage.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "\n", "plt.show()" ] @@ -21147,15 +22773,38 @@ } ], "source": [ - "fig, axs = plt.subplots(2, 1, figsize=(6, 7), sharex=True, gridspec_kw={'height_ratios': [1, 5]})\n", - "plt.subplots_adjust(hspace=0.05)#right=0.975, left=0.11)\n", - "#sns.histplot(ax=axs[0], data=df_cas22, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\", \n", - "sns.kdeplot(ax=axs[0], data=df_test, x=\"coverage\", bw_adjust=0.2, color=\"black\", fill=True, 
multiple=\"layer\", hue=\"Dataset\", common_norm=False, palette=tri_palette[1:]) #hue=\"Precursor_type\", \n", - "axs[0].spines['top'].set_visible(False)\n", - "axs[0].spines['right'].set_visible(False)\n", + "fig, axs = plt.subplots(\n", + " 2, 1, figsize=(6, 7), sharex=True, gridspec_kw={\"height_ratios\": [1, 5]}\n", + ")\n", + "plt.subplots_adjust(hspace=0.05) # right=0.975, left=0.11)\n", + "# sns.histplot(ax=axs[0], data=df_cas22, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", + "sns.kdeplot(\n", + " ax=axs[0],\n", + " data=df_test,\n", + " x=\"coverage\",\n", + " bw_adjust=0.2,\n", + " color=\"black\",\n", + " fill=True,\n", + " multiple=\"layer\",\n", + " hue=\"Dataset\",\n", + " common_norm=False,\n", + " palette=tri_palette[1:],\n", + ") # hue=\"Precursor_type\",\n", + "axs[0].spines[\"top\"].set_visible(False)\n", + "axs[0].spines[\"right\"].set_visible(False)\n", "\n", "axs[0].set_title(\"Impact of coverage on cosine scores\")\n", - "sns.scatterplot(ax=axs[1], data=df_test, x=\"coverage\", y=\"spectral_sqrt_cosine\", hue=\"Dataset\", style=\"Dataset\", markers=[\".\", \"X\", \"*\"][1:], marker=\".\", palette=tri_palette[1:]) #, hue_norm=(0, 1), palette=bluepink_grad)\n", + "sns.scatterplot(\n", + " ax=axs[1],\n", + " data=df_test,\n", + " x=\"coverage\",\n", + " y=\"spectral_sqrt_cosine\",\n", + " hue=\"Dataset\",\n", + " style=\"Dataset\",\n", + " markers=[\".\", \"X\", \"*\"][1:],\n", + " marker=\".\",\n", + " palette=tri_palette[1:],\n", + ") # , hue_norm=(0, 1), palette=bluepink_grad)\n", "axs[1].set_ylabel(\"Cosine similarity\")\n", "axs[1].set_xlabel(\"Peak intensity coverage\")\n", "plt.show()" @@ -21332,7 +22981,11 @@ ], "source": [ "ids = [17, 58, 68, 75, 102, 107, 128, 145, 163, 205]\n", - "example_input = df_cas[df_cas[\"is_priority\"]].loc[ids].copy()[[\"SMILES\", \"avg_CE\", \"Precursor_type\", \"Instrument_type\", \"peaks\"]]\n", + "example_input = (\n", + " df_cas[df_cas[\"is_priority\"]]\n", + " .loc[ids]\n", + " .copy()[[\"SMILES\", \"avg_CE\", \"Precursor_type\", \"Instrument_type\", \"peaks\"]]\n", + ")\n", "example_input[\"CE\"] = example_input[\"avg_CE\"].astype(int)\n", "example_input[\"Name\"] = [f\"Example_{i}\" for i in range(example_input.shape[0])]\n", "# example_input[[\"Name\", \"SMILES\", \"Precursor_type\", \"CE\", \"Instrument_type\"]].to_csv(\"../examples/example_input.csv\", index=False)\n", @@ -21346,11 +22999,13 @@ "outputs": [], "source": [ "for key in df_msnlib_test.iloc[0][\"Metabolite\"].match_stats.keys():\n", - " df_msnlib_test[key] = df_msnlib_test[\"Metabolite\"].apply(lambda x: x.match_stats[key])\n", + " df_msnlib_test[key] = df_msnlib_test[\"Metabolite\"].apply(\n", + " lambda x: x.match_stats[key]\n", + " )\n", " df_test[key] = df_test[\"Metabolite\"].apply(lambda x: x.match_stats[key])\n", " df_cas[key] = df_cas[\"Metabolite\"].apply(lambda x: x.match_stats[key])\n", " df_cas22[key] = df_cas22[\"Metabolite\"].apply(lambda x: x.match_stats[key])\n", - " C[key] = C[\"Metabolite\"].apply(lambda x: x.match_stats[key])\n" + " C[key] = C[\"Metabolite\"].apply(lambda x: x.match_stats[key])" ] }, { @@ -21411,13 +23066,11 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "\n", "# columns = [\"Name\", \"Dataset\", \"lib\", \"origin\", \"SMILES\", \"InChI\", \"InChIKey\", \"Superclass\", \"Class\", \"Precursor_type\", \"Instrument_type\", \"PrecursorMZ\", \"CE\", \"NCE\", \"CCS\", \"CCS_pred\", \"RT\", \"RT_pred\", 
'spectral_cosine', 'spectral_sqrt_cosine',\n", "# 'spectral_sqrt_cosine_wo_prec', 'spectral_refl_cosine', 'spectral_bias',\n", - "# 'spectral_sqrt_bias', 'spectral_sqrt_bias_wo_prec', 'spectral_refl_bias', \n", - "# 'cfm_sqrt_cosine', 'cfm_refl_cosine', 'cfm_sqrt_cosine_wo_prec', 'ice_name', \n", - "# 'ice_cosine','ice_sqrt_cosine', 'ice_sqrt_cosine_wo_prec', \"coverage\", \"coverage_wo_prec\", 'precursor_raw_prob', 'num_peak_matches', 'percent_peak_matches', 'tanimoto3'] \n", + "# 'spectral_sqrt_bias', 'spectral_sqrt_bias_wo_prec', 'spectral_refl_bias',\n", + "# 'cfm_sqrt_cosine', 'cfm_refl_cosine', 'cfm_sqrt_cosine_wo_prec', 'ice_name',\n", + "# 'ice_cosine','ice_sqrt_cosine', 'ice_sqrt_cosine_wo_prec', \"coverage\", \"coverage_wo_prec\", 'precursor_raw_prob', 'num_peak_matches', 'percent_peak_matches', 'tanimoto3']\n", "# C[columns].to_excel(f\"{home}/images/paper/SourceData.xlsx\")\n" ] }, @@ -21557,15 +23210,25 @@ "source": [ "reset_matplotlib()\n", "\n", - "df_print = df_cas[(df_cas[\"merged_sqrt_cosine\"] > 0.70) & (df_cas[\"merged_sqrt_bias\"] < 0.6)]\n", - "#df_print = df_test[(df_test[\"spectral_sqrt_cosine\"] > 0.85) & (df_test[\"spectral_sqrt_bias\"] < 0.6) & (df_test[\"lib\"] == \"MSDIAL\")]\n", + "df_print = df_cas[\n", + " (df_cas[\"merged_sqrt_cosine\"] > 0.70) & (df_cas[\"merged_sqrt_bias\"] < 0.6)\n", + "]\n", + "# df_print = df_test[(df_test[\"spectral_sqrt_cosine\"] > 0.85) & (df_test[\"spectral_sqrt_bias\"] < 0.6) & (df_test[\"lib\"] == \"MSDIAL\")]\n", "\n", "print(df_print.shape)\n", "for i, data in df_print.head(5).iterrows():\n", - " fig, axs = plt.subplots(1, 2, figsize=(9,3), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", - " img = data[\"Metabolite\"].draw(ax= axs[0])\n", + " fig, axs = plt.subplots(\n", + " 1, 2, figsize=(9, 3), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " )\n", + " img = data[\"Metabolite\"].draw(ax=axs[0])\n", " print(i, data[\"ChallengeName\"])\n", - " sv.plot_spectrum(data, {\"peaks\": data[\"sim_peaks\"]}, highlight_matches=True, ppm_tolerance=200, ax=axs[1])\n", + " sv.plot_spectrum(\n", + " data,\n", + " {\"peaks\": data[\"sim_peaks\"]},\n", + " highlight_matches=True,\n", + " ppm_tolerance=200,\n", + " ax=axs[1],\n", + " )\n", " plt.show()" ] }, @@ -21864,20 +23527,44 @@ "# print(i)\n", "# sv.plot_spectrum(data, {\"peaks\": data[\"sim_peaks\"]}, highlight_matches=True, ppm_tolerance=200, ax=axs[1])\n", "# plt.show()\n", - " \n", - " \n", - "smallbutrelatable = [85073, 80082, 1916, 80698, 80083, 82664, 83920, 84096, 84102, 84110, 85073, 89637, 90222, 95985, 95988]\n", + "\n", + "\n", + "smallbutrelatable = [\n", + " 85073,\n", + " 80082,\n", + " 1916,\n", + " 80698,\n", + " 80083,\n", + " 82664,\n", + " 83920,\n", + " 84096,\n", + " 84102,\n", + " 84110,\n", + " 85073,\n", + " 89637,\n", + " 90222,\n", + " 95985,\n", + " 95988,\n", + "]\n", "\n", "for i in smallbutrelatable:\n", " data = df_test.loc[i]\n", " print(i, data[\"CE\"], data[\"Precursor_type\"])\n", " f = peaks_by_intensity(data[\"sim_peaks\"])\n", - " fig, ax = plt.subplots(1,1,figsize=(1.5,1.5))\n", + " fig, ax = plt.subplots(1, 1, figsize=(1.5, 1.5))\n", " Metabolite(f[0][2].split(\"//\")[0]).draw()\n", " plt.show()\n", - " fig, axs = plt.subplots(1, 2, figsize=(9,3), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", - " img = data[\"Metabolite\"].draw(ax= axs[0])\n", - " sv.plot_spectrum(data, {\"peaks\": data[\"sim_peaks\"]}, highlight_matches=True, ppm_tolerance=200, ax=axs[1])\n", + " fig, axs = plt.subplots(\n", + " 1, 2, figsize=(9, 3), 
gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " )\n", + " img = data[\"Metabolite\"].draw(ax=axs[0])\n", + " sv.plot_spectrum(\n", + " data,\n", + " {\"peaks\": data[\"sim_peaks\"]},\n", + " highlight_matches=True,\n", + " ppm_tolerance=200,\n", + " ax=axs[1],\n", + " )\n", " plt.show()" ] }, @@ -21905,7 +23592,9 @@ } ], "source": [ - "df_train[df_train[\"Name\"] == \"Indole-3-acetyl-L-alanine\"][\"Metabolite\"].iloc[0].tanimoto_similarity(data[\"Metabolite\"])\n" + "df_train[df_train[\"Name\"] == \"Indole-3-acetyl-L-alanine\"][\"Metabolite\"].iloc[\n", + " 0\n", + "].tanimoto_similarity(data[\"Metabolite\"])" ] }, { @@ -21948,25 +23637,33 @@ "print(f[0])\n", "Metabolite(f[0][2].split(\"//\")[0]).draw()\n", "plt.show()\n", - "fig, axs = plt.subplots(1, 2, figsize=(9,3), gridspec_kw={'width_ratios': [1, 3]}, sharey=False)\n", - "img = data[\"Metabolite\"].draw(ax= axs[0])\n", + "fig, axs = plt.subplots(\n", + " 1, 2, figsize=(9, 3), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + ")\n", + "img = data[\"Metabolite\"].draw(ax=axs[0])\n", "\n", "\n", - "sv.plot_spectrum(data, {\"peaks\": data[\"sim_peaks\"]}, highlight_matches=True, ppm_tolerance=200, ax=axs[1])\n", + "sv.plot_spectrum(\n", + " data,\n", + " {\"peaks\": data[\"sim_peaks\"]},\n", + " highlight_matches=True,\n", + " ppm_tolerance=200,\n", + " ax=axs[1],\n", + ")\n", "# fig.savefig(f\"{home}/images/paper/ex_mirror.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "plt.show()\n", "\n", "\n", - "fig, ax = plt.subplots(1, 1, figsize=(9,3))\n", + "fig, ax = plt.subplots(1, 1, figsize=(9, 3))\n", "sv.plot_spectrum(data, ax=ax)\n", "# fig.savefig(f\"{home}/images/paper/ex_original.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "\n", "plt.show()\n", - "fig, ax = plt.subplots(1, 1, figsize=(9,3))\n", + "fig, ax = plt.subplots(1, 1, figsize=(9, 3))\n", "sv.plot_spectrum({\"peaks\": data[\"sim_peaks\"]}, ax=ax)\n", "# fig.savefig(f\"{home}/images/paper/ex_prediction.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -22151,8 +23848,8 @@ } ], "source": [ - "s=\"C=c1c[nH+]c2ccccc12\"\n", - "#s=\"C=C1C=[NH+]C2=C1C=CC=C2\"\n", + "s = \"C=c1c[nH+]c2ccccc12\"\n", + "# s=\"C=C1C=[NH+]C2=C1C=CC=C2\"\n", "m = Metabolite(s)\n", "m.draw()\n", "plt.show()\n", @@ -22183,7 +23880,7 @@ "\n", "for i in range(len(f)):\n", " print(f[i])\n", - " fig, ax = plt.subplots(1, 1, figsize=(2,2))\n", + " fig, ax = plt.subplots(1, 1, figsize=(2, 2))\n", " m = Metabolite(f[i][2].split(\"//\")[0])\n", " m.draw()\n", " plt.show()\n", @@ -22191,8 +23888,14 @@ " drawer = rdMolDraw2D.MolDraw2DSVG(500, 500)\n", " drawer.DrawMolecule(m.MOL)\n", " drawer.FinishDrawing()\n", - " cairosvg.svg2pdf(bytestring=drawer.GetDrawingText().encode(), write_to=f\"{home}/images/paper/molecule_f{i}.pdf\")\n", - " cairosvg.svg2svg(bytestring=drawer.GetDrawingText().encode(), write_to=f\"{home}/images/paper/molecule_f{i}.svg\")\n" + " cairosvg.svg2pdf(\n", + " bytestring=drawer.GetDrawingText().encode(),\n", + " write_to=f\"{home}/images/paper/molecule_f{i}.pdf\",\n", + " )\n", + " cairosvg.svg2svg(\n", + " bytestring=drawer.GetDrawingText().encode(),\n", + " write_to=f\"{home}/images/paper/molecule_f{i}.svg\",\n", + " )" ] }, { @@ -22215,15 +23918,21 @@ "from rdkit.Chem.Draw import rdMolDraw2D\n", "import cairosvg\n", "\n", - "#fig, ax = plt.subplots(1,1, figsize=(10, 10))\n", - "#data[\"Metabolite\"].draw(ax= ax)\n", + "# fig, ax = plt.subplots(1,1, 
figsize=(10, 10))\n", + "# data[\"Metabolite\"].draw(ax= ax)\n", "\n", "if False:\n", " drawer = rdMolDraw2D.MolDraw2DSVG(500, 500)\n", " drawer.DrawMolecule(data[\"Metabolite\"].MOL)\n", " drawer.FinishDrawing()\n", - " cairosvg.svg2pdf(bytestring=drawer.GetDrawingText().encode(), write_to=f\"{home}/images/paper/molecule.pdf\")\n", - " cairosvg.svg2svg(bytestring=drawer.GetDrawingText().encode(), write_to=f\"{home}/images/paper/molecule.svg\")" + " cairosvg.svg2pdf(\n", + " bytestring=drawer.GetDrawingText().encode(),\n", + " write_to=f\"{home}/images/paper/molecule.pdf\",\n", + " )\n", + " cairosvg.svg2svg(\n", + " bytestring=drawer.GetDrawingText().encode(),\n", + " write_to=f\"{home}/images/paper/molecule.svg\",\n", + " )" ] }, { @@ -22243,35 +23952,60 @@ } ], "source": [ - "\n", "import matplotlib.patches as mpatches\n", "\n", + "\n", "def double_mirrorplot(i, model_title=\"Fiora\"):\n", - " fig, axs = plt.subplots(1, 3, figsize=(16.8, 4.2), gridspec_kw={'width_ratios': [1, 3, 3]}, sharey=False)\n", - " \n", + " fig, axs = plt.subplots(\n", + " 1, 3, figsize=(16.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3, 3]}, sharey=False\n", + " )\n", + "\n", " plt.subplots_adjust(right=0.975, left=0.025)\n", - " \n", - " img = df_cas.loc[i][\"Metabolite\"].draw(ax= axs[0])\n", + "\n", + " img = df_cas.loc[i][\"Metabolite\"].draw(ax=axs[0])\n", "\n", " axs[0].grid(False)\n", - " axs[0].tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False) \n", - " axs[0].set_title(df_cas.loc[i][\"NAME\"] + \"\\n(\" + df_cas.loc[i][\"ChallengeName\"]+ \")\")\n", + " axs[0].tick_params(\n", + " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " )\n", + " axs[0].set_title(\n", + " df_cas.loc[i][\"NAME\"] + \"\\n(\" + df_cas.loc[i][\"ChallengeName\"] + \")\"\n", + " )\n", " axs[0].imshow(img)\n", " axs[0].axis(\"off\")\n", "\n", - " sv.plot_spectrum({\"peaks\": df_cas.loc[i][\"peaks\"]}, {\"peaks\": df_cas.loc[i][\"merged_peaks\"]}, ax=axs[1])\n", + " sv.plot_spectrum(\n", + " {\"peaks\": df_cas.loc[i][\"peaks\"]},\n", + " {\"peaks\": df_cas.loc[i][\"merged_peaks\"]},\n", + " ax=axs[1],\n", + " )\n", " axs[1].title.set_text(model_title)\n", - " patch1 = mpatches.Patch(color='limegreen' if df_cas.loc[i][\"ice_sqrt_cosine\"] < df_cas.loc[i][\"merged_sqrt_cosine\"] else \"orangered\", label=f'cosine {df_cas.loc[i][\"merged_sqrt_cosine\"]:.02f}')\n", + " patch1 = mpatches.Patch(\n", + " color=\"limegreen\"\n", + " if df_cas.loc[i][\"ice_sqrt_cosine\"] < df_cas.loc[i][\"merged_sqrt_cosine\"]\n", + " else \"orangered\",\n", + " label=f\"cosine {df_cas.loc[i]['merged_sqrt_cosine']:.02f}\",\n", + " )\n", " axs[1].legend(handles=[patch1])\n", "\n", - " sv.plot_spectrum({\"peaks\": df_cas.loc[i][\"peaks\"]}, {\"peaks\": df_cas.loc[i][\"ice_peaks\"]} if df_cas.loc[i][\"ice_peaks\"] else {\"peaks\": {\"mz\": [0], \"intensity\": [0]}}, ax=axs[2])\n", - " axs[2].title.set_text(f'ICEBERG')\n", - " \n", - " patch2 = mpatches.Patch(color='limegreen' if df_cas.loc[i][\"ice_sqrt_cosine\"] > df_cas.loc[i][\"merged_sqrt_cosine\"] else \"orangered\", label=f'cosine {df_cas.loc[i][\"ice_sqrt_cosine\"]:.02f}', )\n", + " sv.plot_spectrum(\n", + " {\"peaks\": df_cas.loc[i][\"peaks\"]},\n", + " {\"peaks\": df_cas.loc[i][\"ice_peaks\"]}\n", + " if df_cas.loc[i][\"ice_peaks\"]\n", + " else {\"peaks\": {\"mz\": [0], \"intensity\": [0]}},\n", + " ax=axs[2],\n", + " )\n", + " axs[2].title.set_text(f\"ICEBERG\")\n", + "\n", + " patch2 = mpatches.Patch(\n", + " color=\"limegreen\"\n", 
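+    "        # limegreen when ICEBERG outscores the merged Fiora prediction, else orangered\n",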
+ " if df_cas.loc[i][\"ice_sqrt_cosine\"] > df_cas.loc[i][\"merged_sqrt_cosine\"]\n", + " else \"orangered\",\n", + " label=f\"cosine {df_cas.loc[i]['ice_sqrt_cosine']:.02f}\",\n", + " )\n", " axs[2].legend(handles=[patch2])\n", - " \n", - " return fig, axs\n", - "\n" + "\n", + " return fig, axs" ] }, { @@ -22334,12 +24068,22 @@ } ], "source": [ - "fig, axs = plt.subplots(2,1,figsize=(8, 6), sharex=True)\n", + "fig, axs = plt.subplots(2, 1, figsize=(8, 6), sharex=True)\n", "\n", "\n", - "sns.histplot(ax=axs[0], data=df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"], x=\"spectral_sqrt_cosine\", hue=\"lib\")\n", + "sns.histplot(\n", + " ax=axs[0],\n", + " data=df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"],\n", + " x=\"spectral_sqrt_cosine\",\n", + " hue=\"lib\",\n", + ")\n", "axs[0].set_title(\"[M+H]+ Test split\")\n", - "sns.histplot(ax=axs[1], data=df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"], x=\"spectral_sqrt_cosine\", hue=\"lib\")\n", + "sns.histplot(\n", + " ax=axs[1],\n", + " data=df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"],\n", + " x=\"spectral_sqrt_cosine\",\n", + " hue=\"lib\",\n", + ")\n", "axs[1].set_title(\"[M-H]- Test split\")\n", "plt.show()" ] @@ -22361,14 +24105,23 @@ } ], "source": [ - "\n", "print(df_cas.groupby(\"Precursor_type\")[\"coverage\"].median())\n", - "fig, axs = plt.subplots(1,2, figsize=(12, 6))\n", + "fig, axs = plt.subplots(1, 2, figsize=(12, 6))\n", "\n", "sns.kdeplot(ax=axs[0], data=df_cas, x=\"coverage\", hue=\"Precursor_type\", bw_adjust=0.5)\n", - "sns.scatterplot(ax=axs[1], data=df_cas, x=\"coverage\", y=\"spectral_sqrt_cosine\", hue=\"Precursor_type\")\n", - "axs[0].axvline(x=df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\"coverage\"].median(), color='black', linestyle='--')\n", - "axs[1].axvline(x=df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\"coverage\"].median(), color='black', linestyle='--')\n", + "sns.scatterplot(\n", + " ax=axs[1], data=df_cas, x=\"coverage\", y=\"spectral_sqrt_cosine\", hue=\"Precursor_type\"\n", + ")\n", + "axs[0].axvline(\n", + " x=df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\"coverage\"].median(),\n", + " color=\"black\",\n", + " linestyle=\"--\",\n", + ")\n", + "axs[1].axvline(\n", + " x=df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\"coverage\"].median(),\n", + " color=\"black\",\n", + " linestyle=\"--\",\n", + ")\n", "plt.show()" ] }, @@ -22407,7 +24160,9 @@ } ], "source": [ - "sns.kdeplot(df_cas22, x=\"CE\", y=\"spectral_sqrt_cosine\", hue=\"Precursor_type\")#, hue=\"lib\")\n", + "sns.kdeplot(\n", + " df_cas22, x=\"CE\", y=\"spectral_sqrt_cosine\", hue=\"Precursor_type\"\n", + ") # , hue=\"lib\")\n", "plt.show()" ] }, @@ -22446,8 +24201,9 @@ } ], "source": [ - "df_test[\"pp_score_dif\"] = df_test[\"spectral_sqrt_cosine\"] - df_test[\"spectral_sqrt_cosine_wo_prec\"]\n", - "\n" + "df_test[\"pp_score_dif\"] = (\n", + " df_test[\"spectral_sqrt_cosine\"] - df_test[\"spectral_sqrt_cosine_wo_prec\"]\n", + ")" ] }, { @@ -22467,7 +24223,7 @@ } ], "source": [ - "sns.boxplot(df_test, x=\"pp_score_dif\")#, hue=\"lib\")\n", + "sns.boxplot(df_test, x=\"pp_score_dif\") # , hue=\"lib\")\n", "plt.xlim(-0.5, 0.5)\n", "plt.show()" ] @@ -22489,7 +24245,7 @@ } ], "source": [ - "sns.scatterplot(df_test, x=\"pp_score_dif\", y=\"pp\", hue=\"CE\")#, hue=\"lib\")\n", + "sns.scatterplot(df_test, x=\"pp_score_dif\", y=\"pp\", hue=\"CE\") # , hue=\"lib\")\n", "plt.xlim(-0.2, 0.6)\n", "plt.show()" ] @@ -22511,7 +24267,7 @@ } ], "source": [ - "sns.histplot(df_cas22, x=\"CE\", hue=\"Precursor_type\", binwidth=3)#, 
hue=\"lib\")\n", + "sns.histplot(df_cas22, x=\"CE\", hue=\"Precursor_type\", binwidth=3) # , hue=\"lib\")\n", "plt.xlim(0, 100)\n", "plt.show()" ] @@ -22533,7 +24289,9 @@ } ], "source": [ - "sns.histplot(df_cas, x=\"CE\", hue=\"Precursor_type\", hue_order=[\"[M+H]+\", \"[M-H]-\"], binwidth=3)#, hue=\"lib\")\n", + "sns.histplot(\n", + " df_cas, x=\"CE\", hue=\"Precursor_type\", hue_order=[\"[M+H]+\", \"[M-H]-\"], binwidth=3\n", + ") # , hue=\"lib\")\n", "plt.xlim(0, 100)\n", "plt.show()" ] @@ -22555,10 +24313,14 @@ } ], "source": [ - "fig, axs = plt.subplots(1,2,figsize=(8, 4))\n", + "fig, axs = plt.subplots(1, 2, figsize=(8, 4))\n", "\n", - "sns.boxplot(ax= axs[0], data=df_test, y=\"Precursor_ppm_error\", hue=\"lib\", showfliers=False)\n", - "sns.boxplot(ax= axs[1], data=df_test, y=\"Precursor_abs_error\", hue=\"lib\", showfliers=False)\n", + "sns.boxplot(\n", + " ax=axs[0], data=df_test, y=\"Precursor_ppm_error\", hue=\"lib\", showfliers=False\n", + ")\n", + "sns.boxplot(\n", + " ax=axs[1], data=df_test, y=\"Precursor_abs_error\", hue=\"lib\", showfliers=False\n", + ")\n", "plt.show()" ] }, @@ -22581,8 +24343,8 @@ "source": [ "print(list(df_test.columns))\n", "\n", - "#sns.kdeplot(df_test, x=\"CE\", y=\"spectral_sqrt_cosine\", hue=\"lib\")\n", - "#plt.show()" + "# sns.kdeplot(df_test, x=\"CE\", y=\"spectral_sqrt_cosine\", hue=\"lib\")\n", + "# plt.show()" ] }, { @@ -22602,7 +24364,14 @@ } ], "source": [ - "sns.kdeplot(data=df_test[(df_test[\"Precursor_ppm_error\"] < 1) & (df_test[\"Precursor_ppm_error\"] < 2)], y=\"spectral_sqrt_cosine\", x=\"Precursor_ppm_error\", hue=\"lib\")\n", + "sns.kdeplot(\n", + " data=df_test[\n", + " (df_test[\"Precursor_ppm_error\"] < 1) & (df_test[\"Precursor_ppm_error\"] < 2)\n", + " ],\n", + " y=\"spectral_sqrt_cosine\",\n", + " x=\"Precursor_ppm_error\",\n", + " hue=\"lib\",\n", + ")\n", "plt.xlim(0, 2)\n", "plt.show()" ] @@ -22651,15 +24420,28 @@ } ], "source": [ - "\n", "print(df_cast.groupby(\"Precursor_type\")[\"coverage\"].median())\n", - "fig, axs = plt.subplots(1,2, figsize=(12, 6))\n", + "fig, axs = plt.subplots(1, 2, figsize=(12, 6))\n", "\n", "sns.kdeplot(ax=axs[0], data=df_cast, x=\"coverage\", hue=\"Precursor_type\", bw_adjust=0.5)\n", - "sns.scatterplot(ax=axs[1], data=df_cast, x=\"coverage\", y=\"spectral_sqrt_cosine\", hue=\"Precursor_type\")\n", - "axs[0].axvline(x=df_cast[df_cast[\"Precursor_type\"] == \"[M+H]+\"][\"coverage\"].median(), color='black', linestyle='--')\n", - "axs[1].axvline(x=df_cast[df_cast[\"Precursor_type\"] == \"[M+H]+\"][\"coverage\"].median(), color='black', linestyle='--')\n", - "plt.show()\n" + "sns.scatterplot(\n", + " ax=axs[1],\n", + " data=df_cast,\n", + " x=\"coverage\",\n", + " y=\"spectral_sqrt_cosine\",\n", + " hue=\"Precursor_type\",\n", + ")\n", + "axs[0].axvline(\n", + " x=df_cast[df_cast[\"Precursor_type\"] == \"[M+H]+\"][\"coverage\"].median(),\n", + " color=\"black\",\n", + " linestyle=\"--\",\n", + ")\n", + "axs[1].axvline(\n", + " x=df_cast[df_cast[\"Precursor_type\"] == \"[M+H]+\"][\"coverage\"].median(),\n", + " color=\"black\",\n", + " linestyle=\"--\",\n", + ")\n", + "plt.show()" ] }, { @@ -22679,7 +24461,7 @@ } ], "source": [ - "for i,m in enumerate(df_cast[\"Metabolite\"]):\n", + "for i, m in enumerate(df_cast[\"Metabolite\"]):\n", " for m2 in df_train[\"Metabolite\"]:\n", " if m == m2:\n", " print(i)\n", @@ -22706,7 +24488,10 @@ "df_cas22[\"Name\"] = df_cas22[\"ChallengeName\"]\n", "df_msnlib_test[\"Name\"] = df_msnlib_test[\"NAME\"]\n", "df_msnlib_test[\"Instrument_type\"] = \"HCD\"\n", - 
"TEST = pd.concat([df_test[df_test[\"lib\"] != \"NIST\"], df_msnlib_test, df_cas, df_cas22], ignore_index=True)" + "TEST = pd.concat(\n", + " [df_test[df_test[\"lib\"] != \"NIST\"], df_msnlib_test, df_cas, df_cas22],\n", + " ignore_index=True,\n", + ")" ] }, { @@ -22728,7 +24513,8 @@ "source": [ "import importlib\n", "import fiora.IO.mgfWriter as mgfWriter\n", - "importlib.reload(mgfWriter)\n" + "\n", + "importlib.reload(mgfWriter)" ] }, { @@ -22739,9 +24525,28 @@ "source": [ "import fiora.IO.mgfWriter as mgfWriter\n", "\n", - "headers = [\"TITLE\", \"SMILES\", \"PRECURSORTYPE\", \"COLLISIONENERGY\", \"INSTRUMENTTYPE\", \"SOURCE\"]\n", - "mgfWriter.write_mgf(TEST, path=f\"{home}/data/archive/fiora_source_data/testing/ground_truth_spectra.mgf\", write_header=True, headers=headers, header_map={\"TITLE\": \"Name\", \"PRECURSORTYPE\": \"Precursor_type\", \"INSTRUMENTTYPE\": \"Instrument_type\", \"COLLISIONENERGY\": \"CE\", \"SOURCE\": \"lib\"}, annotation=False)\n", - "\n" + "headers = [\n", + " \"TITLE\",\n", + " \"SMILES\",\n", + " \"PRECURSORTYPE\",\n", + " \"COLLISIONENERGY\",\n", + " \"INSTRUMENTTYPE\",\n", + " \"SOURCE\",\n", + "]\n", + "mgfWriter.write_mgf(\n", + " TEST,\n", + " path=f\"{home}/data/archive/fiora_source_data/testing/ground_truth_spectra.mgf\",\n", + " write_header=True,\n", + " headers=headers,\n", + " header_map={\n", + " \"TITLE\": \"Name\",\n", + " \"PRECURSORTYPE\": \"Precursor_type\",\n", + " \"INSTRUMENTTYPE\": \"Instrument_type\",\n", + " \"COLLISIONENERGY\": \"CE\",\n", + " \"SOURCE\": \"lib\",\n", + " },\n", + " annotation=False,\n", + ")" ] }, { @@ -22750,9 +24555,31 @@ "metadata": {}, "outputs": [], "source": [ - "headers = [\"TITLE\", \"SMILES\", \"PRECURSORTYPE\", \"COLLISIONENERGY\", \"INSTRUMENTTYPE\", \"SOURCE\", \"COMMENT\"]\n", - "TEST[\"COMMENT\"] = f\"\\\"In silico generated spectrum by Fiora (pre-release version)\\\"\"\n", - "mgfWriter.write_mgf(TEST, peak_tag=\"sim_peaks\", path=f\"{home}/data/archive/fiora_source_data/testing/fiora_predicted_spectra.mgf\", write_header=True, headers=headers, header_map={\"TITLE\": \"Name\", \"PRECURSORTYPE\": \"Precursor_type\", \"INSTRUMENTTYPE\": \"Instrument_type\", \"COLLISIONENERGY\": \"CE\", \"SOURCE\": \"lib\"}, annotation=False)\n" + "headers = [\n", + " \"TITLE\",\n", + " \"SMILES\",\n", + " \"PRECURSORTYPE\",\n", + " \"COLLISIONENERGY\",\n", + " \"INSTRUMENTTYPE\",\n", + " \"SOURCE\",\n", + " \"COMMENT\",\n", + "]\n", + "TEST[\"COMMENT\"] = f'\"In silico generated spectrum by Fiora (pre-release version)\"'\n", + "mgfWriter.write_mgf(\n", + " TEST,\n", + " peak_tag=\"sim_peaks\",\n", + " path=f\"{home}/data/archive/fiora_source_data/testing/fiora_predicted_spectra.mgf\",\n", + " write_header=True,\n", + " headers=headers,\n", + " header_map={\n", + " \"TITLE\": \"Name\",\n", + " \"PRECURSORTYPE\": \"Precursor_type\",\n", + " \"INSTRUMENTTYPE\": \"Instrument_type\",\n", + " \"COLLISIONENERGY\": \"CE\",\n", + " \"SOURCE\": \"lib\",\n", + " },\n", + " annotation=False,\n", + ")" ] }, { @@ -22761,8 +24588,22 @@ "metadata": {}, "outputs": [], "source": [ - "TEST[\"COMMENT\"] = f\"\\\"In silico generated spectrum by ICEBERG\\\"\"\n", - "mgfWriter.write_mgf(TEST[TEST[\"Precursor_type\"] == \"[M+H]+\"], peak_tag=\"ice_peaks\", path=f\"{home}/data/archive/fiora_source_data/testing/iceberg_predicted_spectra.mgf\", write_header=True, headers=headers, header_map={\"TITLE\": \"Name\", \"PRECURSORTYPE\": \"Precursor_type\", \"INSTRUMENTTYPE\": \"Instrument_type\", \"COLLISIONENERGY\": \"CE\", \"SOURCE\": \"lib\"}, 
annotation=False)" + "TEST[\"COMMENT\"] = f'\"In silico generated spectrum by ICEBERG\"'\n", + "mgfWriter.write_mgf(\n", + " TEST[TEST[\"Precursor_type\"] == \"[M+H]+\"],\n", + " peak_tag=\"ice_peaks\",\n", + " path=f\"{home}/data/archive/fiora_source_data/testing/iceberg_predicted_spectra.mgf\",\n", + " write_header=True,\n", + " headers=headers,\n", + " header_map={\n", + " \"TITLE\": \"Name\",\n", + " \"PRECURSORTYPE\": \"Precursor_type\",\n", + " \"INSTRUMENTTYPE\": \"Instrument_type\",\n", + " \"COLLISIONENERGY\": \"CE\",\n", + " \"SOURCE\": \"lib\",\n", + " },\n", + " annotation=False,\n", + ")" ] }, { @@ -22771,8 +24612,22 @@ "metadata": {}, "outputs": [], "source": [ - "TEST[\"COMMENT\"] = f\"\\\"In silico generated spectrum by CFM-ID (v4.4.7)\\\"\"\n", - "mgfWriter.write_mgf(TEST[~TEST[\"cfm_peaks\"].isna()], peak_tag=\"cfm_peaks\", path=f\"{home}/data/archive/fiora_source_data/testing/cfm-id_predicted_spectra.mgf\", write_header=True, headers=headers, header_map={\"TITLE\": \"Name\", \"PRECURSORTYPE\": \"Precursor_type\", \"INSTRUMENTTYPE\": \"Instrument_type\", \"COLLISIONENERGY\": \"CE\", \"SOURCE\": \"lib\"}, annotation=False)" + "TEST[\"COMMENT\"] = f'\"In silico generated spectrum by CFM-ID (v4.4.7)\"'\n", + "mgfWriter.write_mgf(\n", + " TEST[~TEST[\"cfm_peaks\"].isna()],\n", + " peak_tag=\"cfm_peaks\",\n", + " path=f\"{home}/data/archive/fiora_source_data/testing/cfm-id_predicted_spectra.mgf\",\n", + " write_header=True,\n", + " headers=headers,\n", + " header_map={\n", + " \"TITLE\": \"Name\",\n", + " \"PRECURSORTYPE\": \"Precursor_type\",\n", + " \"INSTRUMENTTYPE\": \"Instrument_type\",\n", + " \"COLLISIONENERGY\": \"CE\",\n", + " \"SOURCE\": \"lib\",\n", + " },\n", + " annotation=False,\n", + ")" ] }, { @@ -22790,10 +24645,26 @@ "metadata": {}, "outputs": [], "source": [ - "TEST[\"peaks\"] = TEST[\"peaks\"].apply(lambda x: {k: v for k, v in x.items() if k != \"mz_int\"} if isinstance(x, dict) else x)\n", - "TEST[\"sim_peaks\"] = TEST[\"sim_peaks\"].apply(lambda x: {k: v for k, v in x.items() if k != \"mz_int\"} if isinstance(x, dict) else x)\n", - "TEST[\"ice_peaks\"] = TEST[\"ice_peaks\"].apply(lambda x: {k: v for k, v in x.items() if k != \"mz_int\"} if isinstance(x, dict) else x)\n", - "TEST[\"cfm_peaks\"] = TEST[\"cfm_peaks\"].apply(lambda x: {k: v for k, v in x.items() if k != \"mz_int\"} if isinstance(x, dict) else x)\n" + "TEST[\"peaks\"] = TEST[\"peaks\"].apply(\n", + " lambda x: (\n", + " {k: v for k, v in x.items() if k != \"mz_int\"} if isinstance(x, dict) else x\n", + " )\n", + ")\n", + "TEST[\"sim_peaks\"] = TEST[\"sim_peaks\"].apply(\n", + " lambda x: (\n", + " {k: v for k, v in x.items() if k != \"mz_int\"} if isinstance(x, dict) else x\n", + " )\n", + ")\n", + "TEST[\"ice_peaks\"] = TEST[\"ice_peaks\"].apply(\n", + " lambda x: (\n", + " {k: v for k, v in x.items() if k != \"mz_int\"} if isinstance(x, dict) else x\n", + " )\n", + ")\n", + "TEST[\"cfm_peaks\"] = TEST[\"cfm_peaks\"].apply(\n", + " lambda x: (\n", + " {k: v for k, v in x.items() if k != \"mz_int\"} if isinstance(x, dict) else x\n", + " )\n", + ")" ] }, { @@ -23044,17 +24915,52 @@ "metadata": {}, "outputs": [], "source": [ - "TEST[[\n", - " \"Name\", \"SMILES\", \"InChIKey\", \"group_id\", \"Precursor_type\", \"Instrument_type\", \"CE\", \"Dataset\", \"lib\", \"summary\",\n", - " \"peaks\", \"sim_peaks\", \"ice_peaks\", \"cfm_peaks\", \n", - " 'spectral_cosine', 'spectral_sqrt_cosine', 'spectral_sqrt_cosine_wo_prec', 'spectral_refl_cosine', 'spectral_bias',\n", - " 'spectral_sqrt_bias', 
'spectral_sqrt_bias_wo_prec', 'spectral_refl_bias', 'steins_cosine', 'steins_bias',\n", - " 'cfm_sqrt_cosine', 'cfm_refl_cosine', 'cfm_sqrt_cosine_wo_prec',\n", - " 'cfm_steins', 'ice_name', 'ice_peaks', 'ice_cosine',\n", - " 'ice_sqrt_cosine', 'ice_sqrt_cosine_wo_prec', 'ice_refl_cosine', 'ice_steins',\n", - " \"RETENTIONTIME\", \"RT_pred\", \"CCS\", \"CCS_pred\", \n", - " \"coverage\", \"tanimoto\", \"tanimoto3\"\n", - " ]].to_csv(f\"{home}/data/archive/fiora_source_data/testing/dataframe.csv\")" + "TEST[\n", + " [\n", + " \"Name\",\n", + " \"SMILES\",\n", + " \"InChIKey\",\n", + " \"group_id\",\n", + " \"Precursor_type\",\n", + " \"Instrument_type\",\n", + " \"CE\",\n", + " \"Dataset\",\n", + " \"lib\",\n", + " \"summary\",\n", + " \"peaks\",\n", + " \"sim_peaks\",\n", + " \"ice_peaks\",\n", + " \"cfm_peaks\",\n", + " \"spectral_cosine\",\n", + " \"spectral_sqrt_cosine\",\n", + " \"spectral_sqrt_cosine_wo_prec\",\n", + " \"spectral_refl_cosine\",\n", + " \"spectral_bias\",\n", + " \"spectral_sqrt_bias\",\n", + " \"spectral_sqrt_bias_wo_prec\",\n", + " \"spectral_refl_bias\",\n", + " \"steins_cosine\",\n", + " \"steins_bias\",\n", + " \"cfm_sqrt_cosine\",\n", + " \"cfm_refl_cosine\",\n", + " \"cfm_sqrt_cosine_wo_prec\",\n", + " \"cfm_steins\",\n", + " \"ice_name\",\n", + " \"ice_peaks\",\n", + " \"ice_cosine\",\n", + " \"ice_sqrt_cosine\",\n", + " \"ice_sqrt_cosine_wo_prec\",\n", + " \"ice_refl_cosine\",\n", + " \"ice_steins\",\n", + " \"RETENTIONTIME\",\n", + " \"RT_pred\",\n", + " \"CCS\",\n", + " \"CCS_pred\",\n", + " \"coverage\",\n", + " \"tanimoto\",\n", + " \"tanimoto3\",\n", + " ]\n", + "].to_csv(f\"{home}/data/archive/fiora_source_data/testing/dataframe.csv\")" ] }, { @@ -23084,7 +24990,9 @@ "source": [ "df_msnlib_train[\"Name\"] = df_msnlib_train[\"NAME\"]\n", "df_msnlib_train[\"Instrument_type\"] = \"HCD\"\n", - "TRAIN = pd.concat([df_train[df_train[\"lib\"] != \"NIST\"], df_msnlib_train], ignore_index=True)" + "TRAIN = pd.concat(\n", + " [df_train[df_train[\"lib\"] != \"NIST\"], df_msnlib_train], ignore_index=True\n", + ")" ] }, { @@ -23125,13 +25033,45 @@ "metadata": {}, "outputs": [], "source": [ + "headers = [\n", + " \"TITLE\",\n", + " \"SMILES\",\n", + " \"PRECURSORTYPE\",\n", + " \"COLLISIONENERGY\",\n", + " \"INSTRUMENTTYPE\",\n", + " \"SOURCE\",\n", + "]\n", + "\n", + "mgfWriter.write_mgf(\n", + " TRAIN[TRAIN[\"datasplit\"] == \"training\"],\n", + " path=f\"{home}/data/archive/fiora_source_data/training/training_spectra.mgf\",\n", + " write_header=True,\n", + " headers=headers,\n", + " header_map={\n", + " \"TITLE\": \"Name\",\n", + " \"PRECURSORTYPE\": \"Precursor_type\",\n", + " \"INSTRUMENTTYPE\": \"Instrument_type\",\n", + " \"COLLISIONENERGY\": \"CE\",\n", + " \"SOURCE\": \"lib\",\n", + " },\n", + " annotation=False,\n", + ")\n", + "mgfWriter.write_mgf(\n", + " TRAIN[TRAIN[\"datasplit\"] == \"validation\"],\n", + " path=f\"{home}/data/archive/fiora_source_data/training/validation_spectra.mgf\",\n", + " write_header=True,\n", + " headers=headers,\n", + " header_map={\n", + " \"TITLE\": \"Name\",\n", + " \"PRECURSORTYPE\": \"Precursor_type\",\n", + " \"INSTRUMENTTYPE\": \"Instrument_type\",\n", + " \"COLLISIONENERGY\": \"CE\",\n", + " \"SOURCE\": \"lib\",\n", + " },\n", + " annotation=False,\n", + ")\n", "\n", - "headers = [\"TITLE\", \"SMILES\", \"PRECURSORTYPE\", \"COLLISIONENERGY\", \"INSTRUMENTTYPE\", \"SOURCE\"]\n", - "\n", - "mgfWriter.write_mgf(TRAIN[TRAIN[\"datasplit\"] == \"training\"], 
path=f\"{home}/data/archive/fiora_source_data/training/training_spectra.mgf\", write_header=True, headers=headers, header_map={\"TITLE\": \"Name\", \"PRECURSORTYPE\": \"Precursor_type\", \"INSTRUMENTTYPE\": \"Instrument_type\", \"COLLISIONENERGY\": \"CE\", \"SOURCE\": \"lib\"}, annotation=False)\n", - "mgfWriter.write_mgf(TRAIN[TRAIN[\"datasplit\"] == \"validation\"], path=f\"{home}/data/archive/fiora_source_data/training/validation_spectra.mgf\", write_header=True, headers=headers, header_map={\"TITLE\": \"Name\", \"PRECURSORTYPE\": \"Precursor_type\", \"INSTRUMENTTYPE\": \"Instrument_type\", \"COLLISIONENERGY\": \"CE\", \"SOURCE\": \"lib\"}, annotation=False)\n", - "\n", - "#TRAIN[TRAIN[\"datasplit\"] == \"validation\"] # to MGF" + "# TRAIN[TRAIN[\"datasplit\"] == \"validation\"] # to MGF" ] }, { @@ -23140,12 +25080,29 @@ "metadata": {}, "outputs": [], "source": [ - "TRAIN[\"peaks\"] = TRAIN[\"peaks\"].apply(lambda x: {k: v for k, v in x.items() if k != \"mz_int\"} if isinstance(x, dict) else x)\n", + "TRAIN[\"peaks\"] = TRAIN[\"peaks\"].apply(\n", + " lambda x: (\n", + " {k: v for k, v in x.items() if k != \"mz_int\"} if isinstance(x, dict) else x\n", + " )\n", + ")\n", "\n", - "TRAIN[[\n", - " \"Name\", \"SMILES\", \"InChIKey\", \"group_id\", \"datasplit\", \"Precursor_type\", \"Instrument_type\", \"CE\", \"lib\", \n", - " \"peaks\", \"summary\", \"RETENTIONTIME\", \"CCS\", \n", - " ]].to_csv(f\"{home}/data/archive/fiora_source_data/training/dataframe.csv\")" + "TRAIN[\n", + " [\n", + " \"Name\",\n", + " \"SMILES\",\n", + " \"InChIKey\",\n", + " \"group_id\",\n", + " \"datasplit\",\n", + " \"Precursor_type\",\n", + " \"Instrument_type\",\n", + " \"CE\",\n", + " \"lib\",\n", + " \"peaks\",\n", + " \"summary\",\n", + " \"RETENTIONTIME\",\n", + " \"CCS\",\n", + " ]\n", + "].to_csv(f\"{home}/data/archive/fiora_source_data/training/dataframe.csv\")" ] } ], diff --git a/notebooks/train_model.ipynb b/notebooks/train_model.ipynb index b98d7e0..ee1f97b 100644 --- a/notebooks/train_model.ipynb +++ b/notebooks/train_model.ipynb @@ -31,7 +31,7 @@ "import torch\n", "\n", "seed = 42\n", - "#torch.set_default_dtype(torch.float64)\n", + "# torch.set_default_dtype(torch.float64)\n", "torch.manual_seed(seed)\n", "torch.set_printoptions(precision=2, sci_mode=False)\n", "\n", @@ -44,18 +44,20 @@ "# Load Modules\n", "sys.path.append(\"..\")\n", "from os.path import expanduser\n", + "\n", "home = expanduser(\"~\")\n", "from fiora.MOL.constants import DEFAULT_PPM, PPM, DEFAULT_MODES\n", "from fiora.IO.LibraryLoader import LibraryLoader\n", - "from fiora.MOL.FragmentationTree import FragmentationTree \n", + "from fiora.MOL.FragmentationTree import FragmentationTree\n", "import fiora.visualization.spectrum_visualizer as sv\n", "\n", "from sklearn.metrics import r2_score\n", "import scipy\n", "from rdkit import RDLogger\n", - "RDLogger.DisableLog('rdApp.*')\n", "\n", - "print(f'Working with Python {sys.version}')\n" + "RDLogger.DisableLog(\"rdApp.*\")\n", + "\n", + "print(f\"Working with Python {sys.version}\")" ] }, { @@ -81,12 +83,15 @@ ], "source": [ "from typing import Literal\n", - "lib: Literal[\"NIST\", \"MSDIAL\", \"NIST/MSDIAL\", \"MSnLib\"] = \"MSnLib\" #\"MSnLib\"\n", + "\n", + "lib: Literal[\"NIST\", \"MSDIAL\", \"NIST/MSDIAL\", \"MSnLib\"] = \"MSnLib\" # \"MSnLib\"\n", "print(f\"Preparing {lib} library\")\n", "\n", - "debug_mode = False # Default: False\n", + "debug_mode = False # Default: False\n", "if debug_mode:\n", - " print(\"+++ This is a test run (debug mode) with a small subset of data 
points. Results are not representative. +++\")" + " print(\n", + " \"+++ This is a test run (debug mode) with a small subset of data points. Results are not representative. +++\"\n", + " )" ] }, { @@ -97,41 +102,45 @@ "source": [ "# key map to read metadata from pandas DataFrame\n", "metadata_key_map = {\n", - " \"name\": \"Name\",\n", - " \"collision_energy\": \"CE\", \n", - " \"instrument\": \"Instrument_type\",\n", - " \"ionization\": \"Ionization\",\n", - " \"precursor_mz\": \"PrecursorMZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", - " \"ccs\": \"CCS\"\n", - " }\n", + " \"name\": \"Name\",\n", + " \"collision_energy\": \"CE\",\n", + " \"instrument\": \"Instrument_type\",\n", + " \"ionization\": \"Ionization\",\n", + " \"precursor_mz\": \"PrecursorMZ\",\n", + " \"precursor_mode\": \"Precursor_type\",\n", + " \"retention_time\": \"RETENTIONTIME\",\n", + " \"ccs\": \"CCS\",\n", + "}\n", "\n", "\n", "#\n", "# Load specified libraries and align metadata\n", "#\n", "\n", + "\n", "def load_training_data():\n", - " if (\"NIST\" in lib or \"MSDIAL\" in lib):\n", + " if \"NIST\" in lib or \"MSDIAL\" in lib:\n", " data_path: str = f\"{home}/data/metabolites/preprocessed/datasplits_Jan24.csv\"\n", " elif lib == \"MSnLib\":\n", - " data_path: str = f\"{home}/data/metabolites/preprocessed/datasplits_msnlib_v7_Sep25.csv\"\n", + " data_path: str = (\n", + " f\"{home}/data/metabolites/preprocessed/datasplits_msnlib_v7_Sep25.csv\"\n", + " )\n", " else:\n", " raise NameError(f\"Unknown library selected {lib=}.\")\n", " L = LibraryLoader()\n", " df = L.load_from_csv(data_path)\n", " return df\n", "\n", + "\n", "df = load_training_data()\n", "\n", "# Restore dictionary values\n", "dict_columns = [\"peaks\", \"summary\"]\n", "for col in dict_columns:\n", - " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", - " #df[col] = df[col].apply(ast.literal_eval)\n", - " \n", - "df['group_id'] = df['group_id'].astype(int)\n" + " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace(\"nan\", \"None\")))\n", + " # df[col] = df[col].apply(ast.literal_eval)\n", + "\n", + "df[\"group_id\"] = df[\"group_id\"].astype(int)" ] }, { @@ -161,15 +170,14 @@ "\n", "\n", "if debug_mode:\n", - " df = df.iloc[:1000,:]\n", - " #df = df.iloc[5000:20000,:]\n", + " df = df.iloc[:1000, :]\n", + " # df = df.iloc[5000:20000,:]\n", "\n", "overwrite_setup_features = None\n", "if lib == \"MSnLib\":\n", " overwrite_setup_features = {\n", " \"instrument\": [\"HCD\"],\n", - " \"precursor_mode\": [\"[M+H]+\", \"[M-H]-\", \"[M]+\", \"[M]-\"]\n", - " \n", + " \"precursor_mode\": [\"[M+H]+\", \"[M-H]-\", \"[M]+\", \"[M]-\"],\n", " }\n", "\n", "\n", @@ -178,18 +186,38 @@ "\n", "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", - "covariate_encoder = CovariateFeatureEncoder(feature_list=[\"collision_energy\", \"molecular_weight\", \"precursor_mode\", \"instrument\", \"element_composition\"], sets_overwrite=overwrite_setup_features)\n", - "rt_encoder = CovariateFeatureEncoder(feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\", \"element_composition\"], sets_overwrite=overwrite_setup_features)\n", + "covariate_encoder = CovariateFeatureEncoder(\n", + " feature_list=[\n", + " \"collision_energy\",\n", + " \"molecular_weight\",\n", + " \"precursor_mode\",\n", + " \"instrument\",\n", + " 
\"element_composition\",\n", + " ],\n", + " sets_overwrite=overwrite_setup_features,\n", + ")\n", + "rt_encoder = CovariateFeatureEncoder(\n", + " feature_list=[\n", + " \"molecular_weight\",\n", + " \"precursor_mode\",\n", + " \"instrument\",\n", + " \"element_composition\",\n", + " ],\n", + " sets_overwrite=overwrite_setup_features,\n", + ")\n", "\n", - "covariate_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit \n", - "covariate_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n", - "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit \n", + "covariate_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", + "covariate_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", + "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", "\n", "df[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]) , axis=1)\n", + "df.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]), axis=1)\n", "\n", - "#df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", - "df.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder), axis=1)\n", + "# df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", + "df.apply(\n", + " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder),\n", + " axis=1,\n", + ")\n", "_ = df.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)" ] }, @@ -220,7 +248,9 @@ ], "source": [ "mindex.create_fragmentation_trees()\n", - "list_of_mismatched_ids = mindex.add_fragmentation_trees_to_metabolite_list(df[\"Metabolite\"], graph_mismatch_policy=\"recompute\")\n", + "list_of_mismatched_ids = mindex.add_fragmentation_trees_to_metabolite_list(\n", + " df[\"Metabolite\"], graph_mismatch_policy=\"recompute\"\n", + ")\n", "print(f\"Total number of recomputed trees: {len(list_of_mismatched_ids)}\")" ] }, @@ -230,8 +260,13 @@ "metadata": {}, "outputs": [], "source": [ - "#df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - "_ = df.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]), axis=1)" + "# df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "_ = df.apply(\n", + " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", + " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -248,7 +283,9 @@ } ], "source": [ - "df[\"num_peak_matches\"] = df[\"Metabolite\"].apply(lambda x: x.match_stats[\"num_peak_matches\"])\n", + "df[\"num_peak_matches\"] = df[\"Metabolite\"].apply(\n", + " lambda x: x.match_stats[\"num_peak_matches\"]\n", + ")\n", "print(sum(df[\"num_peak_matches\"] < 1))\n", "df = df[df[\"num_peak_matches\"] >= 2]" ] @@ -304,12 +341,12 @@ "source": [ "import torch_geometric as geom\n", "\n", - "if torch.cuda.is_available(): \n", + "if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " dev = \"cuda:0\"\n", - "else: \n", - " dev = \"cpu\" \n", - " \n", + "else:\n", + " dev = \"cpu\"\n", + "\n", "print(f\"Running on device: {dev}\")" ] }, @@ -378,57 +415,50 @@ "outputs": [], "source": [ "model_params = {\n", - " 
'param_tag': 'default',\n", - "\n", - " #GNN parameters\n", - " 'gnn_type': 'RGCNConv',\n", - " 'depth': 10, # 8\n", - " 'hidden_dimension': 300, # 300\n", - " 'residual_connections': False,\n", - " 'layer_stacking': True, # Avoid residual connections and layer stacking at the same time\n", - " 'embedding_aggregation': 'concat',\n", - " 'embedding_dimension': 300, # 300,\n", - " 'subgraph_features': True,\n", - " 'pooling_func': \"max\", # max or avg\n", - " 'layer_norm': True, \n", - " \n", + " \"param_tag\": \"default\",\n", + " # GNN parameters\n", + " \"gnn_type\": \"RGCNConv\",\n", + " \"depth\": 10, # 8\n", + " \"hidden_dimension\": 300, # 300\n", + " \"residual_connections\": False,\n", + " \"layer_stacking\": True, # Avoid residual connections and layer stacking at the same time\n", + " \"embedding_aggregation\": \"concat\",\n", + " \"embedding_dimension\": 300, # 300,\n", + " \"subgraph_features\": True,\n", + " \"pooling_func\": \"max\", # max or avg\n", + " \"layer_norm\": True,\n", " # Dense layers\n", - " 'dense_layers': 2, # 2 # Number of \"hidden\" dense layers, an additional output layer is always added\n", - " 'dense_dim': 500, # Set to None (then dense dim defaults to GNN output dimension (very large if layer stacking is active))\n", - "\n", + " \"dense_layers\": 2, # 2 # Number of \"hidden\" dense layers, an additional output layer is always added\n", + " \"dense_dim\": 500, # Set to None (then dense dim defaults to GNN output dimension (very large if layer stacking is active))\n", " # Dropout\n", - " 'input_dropout': 0.25, # 0.2,\n", - " 'latent_dropout': 0.25, # 0.1,\n", - " \n", + " \"input_dropout\": 0.25, # 0.2,\n", + " \"latent_dropout\": 0.25, # 0.1,\n", " # Dimensions\n", - " 'node_feature_layout': node_encoder.feature_numbers,\n", - " 'edge_feature_layout': bond_encoder.feature_numbers, \n", - " 'static_feature_dimension': geo_data[0][\"static_edge_features\"].shape[1],\n", - " 'static_rt_feature_dimension': geo_data[0][\"static_rt_features\"].shape[1],\n", - " 'output_dimension': len(DEFAULT_MODES) * 2, # per edge \n", - " \n", + " \"node_feature_layout\": node_encoder.feature_numbers,\n", + " \"edge_feature_layout\": bond_encoder.feature_numbers,\n", + " \"static_feature_dimension\": geo_data[0][\"static_edge_features\"].shape[1],\n", + " \"static_rt_feature_dimension\": geo_data[0][\"static_rt_features\"].shape[1],\n", + " \"output_dimension\": len(DEFAULT_MODES) * 2, # per edge\n", " # Keep track of how features are encoded\n", - " 'atom_features': node_encoder.feature_list,\n", - " 'atom_features': bond_encoder.feature_list,\n", - " 'setup_features': covariate_encoder.feature_list,\n", - " 'setup_features_categorical_set': covariate_encoder.categorical_sets,\n", - " 'rt_features': rt_encoder.feature_list,\n", - " \n", + " \"atom_features\": node_encoder.feature_list,\n", + " \"atom_features\": bond_encoder.feature_list,\n", + " \"setup_features\": covariate_encoder.feature_list,\n", + " \"setup_features_categorical_set\": covariate_encoder.categorical_sets,\n", + " \"rt_features\": rt_encoder.feature_list,\n", " # Set default flags (May be overwritten below)\n", - " 'prepare_additional_layers': False,\n", - " 'rt_supported': False,\n", - " 'ccs_supported': False,\n", - " 'version': \"x.x.x\"\n", - " \n", + " \"prepare_additional_layers\": False,\n", + " \"rt_supported\": False,\n", + " \"ccs_supported\": False,\n", + " \"version\": \"x.x.x\",\n", "}\n", "training_params = {\n", - " 'epochs': 300 if not debug_mode else 10, \n", - " 'batch_size': 32, # 256, # 
256\n", + " \"epochs\": 300 if not debug_mode else 10,\n", + " \"batch_size\": 32, # 256, # 256\n", " #'train_val_split': 0.90,\n", - " 'learning_rate': 2e-4, # 4e-4, \n", - " 'weight_decay': 1e-5, #1e-4,\n", - " 'with_RT': False, # Turn off RT/CCS for initial trainings round\n", - " 'with_CCS': False\n", + " \"learning_rate\": 2e-4, # 4e-4,\n", + " \"weight_decay\": 1e-5, # 1e-4,\n", + " \"with_RT\": False, # Turn off RT/CCS for initial trainings round\n", + " \"with_CCS\": False,\n", "}" ] }, @@ -447,6 +477,7 @@ ], "source": [ "from fiora.GNN.FioraModel import FioraModel\n", + "\n", "model_snapshot = FioraModel(model_params)\n", "# Print num of parameters of model\n", "num_params = sum(p.numel() for p in model_snapshot.parameters() if p.requires_grad)\n", @@ -483,11 +514,16 @@ "source": [ "import numpy as np\n", "\n", + "\n", "# Subsample training data\n", "def subsample_keys(train_keys, val_keys, down_to_fraction: float):\n", - " train_sample = np.random.choice(train_keys, size=int(len(train_keys) * down_to_fraction), replace=False)\n", - " val_sample = np.random.choice(val_keys, size=int(len(val_keys) * down_to_fraction), replace=False)\n", - " return train_sample, val_sample\n" + " train_sample = np.random.choice(\n", + " train_keys, size=int(len(train_keys) * down_to_fraction), replace=False\n", + " )\n", + " val_sample = np.random.choice(\n", + " val_keys, size=int(len(val_keys) * down_to_fraction), replace=False\n", + " )\n", + " return train_sample, val_sample" ] }, { @@ -498,58 +534,117 @@ "source": [ "from fiora.GNN.SpectralTrainer import SpectralTrainer\n", "from fiora.GNN.FioraModel import FioraModel\n", - "from fiora.GNN.Losses import WeightedMSELoss, WeightedMSEMetric, WeightedMAELoss, WeightedMAEMetric, GraphwiseKLLoss, GraphwiseKLLossMetric\n", + "from fiora.GNN.Losses import (\n", + " WeightedMSELoss,\n", + " WeightedMSEMetric,\n", + " WeightedMAELoss,\n", + " WeightedMAEMetric,\n", + " GraphwiseKLLoss,\n", + " GraphwiseKLLossMetric,\n", + ")\n", "from fiora.MS.SimulationFramework import SimulationFramework\n", "\n", "fiora = SimulationFramework(None, dev=dev)\n", "# fiora = SimulationFramework(None, dev=dev, with_RT=training_params[\"with_RT\"], with_CCS=training_params[\"with_CCS\"])\n", - "np.seterr(invalid='ignore')\n", + "np.seterr(invalid=\"ignore\")\n", "tag = \"training\"\n", "val_interval = 1\n", - "metric_dict= {\"mse\": GraphwiseKLLossMetric} #reduction=\"mean\" by default #WeightedMSEMetric\n", - "#loss_fn = WeightedMSELoss() # WeightedMSELoss()\n", + "metric_dict = {\n", + " \"mse\": GraphwiseKLLossMetric\n", + "} # reduction=\"mean\" by default #WeightedMSEMetric\n", + "# loss_fn = WeightedMSELoss() # WeightedMSELoss()\n", "loss_fn = GraphwiseKLLoss(reduction=\"mean\")\n", "all_together = False\n", "down_sample = False\n", "\n", "if all_together:\n", " val_interval = 200\n", - " metric_dict=None\n", - " loss_fn = torch.nn.MSELoss() \n", + " metric_dict = None\n", + " loss_fn = torch.nn.MSELoss()\n", "\n", - "def train_new_model(continue_with_model=None, model_params=model_params, training_params=training_params, tag=\"\"):\n", + "\n", + "def train_new_model(\n", + " continue_with_model=None,\n", + " model_params=model_params,\n", + " training_params=training_params,\n", + " tag=\"\",\n", + "):\n", " if continue_with_model:\n", " model = continue_with_model.to(dev)\n", " else:\n", " model = FioraModel(model_params).to(dev)\n", - " \n", + "\n", " # y_label = 'compiled_probsSQRT' # y_label = 'compiled_probsALL'\n", - " y_label = 'compiled_probsALL'\n", - " 
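# Adam with weight decay; the scheduler chosen below either decays the LR\n",
+    "    # exponentially (all_together mode) or halves it on a validation-loss plateau.\n",
+    "    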
optimizer = torch.optim.Adam(model.parameters(), lr=training_params[\"learning_rate\"], weight_decay=training_params[\"weight_decay\"])\n", + " y_label = \"compiled_probsALL\"\n", + " optimizer = torch.optim.Adam(\n", + " model.parameters(),\n", + " lr=training_params[\"learning_rate\"],\n", + " weight_decay=training_params[\"weight_decay\"],\n", + " )\n", " if all_together:\n", - " trainer = SpectralTrainer(geo_data, y_tag=y_label, problem_type=\"regression\", only_training=True, metric_dict=metric_dict, split_by_group=True, seed=seed, device=dev)\n", + " trainer = SpectralTrainer(\n", + " geo_data,\n", + " y_tag=y_label,\n", + " problem_type=\"regression\",\n", + " only_training=True,\n", + " metric_dict=metric_dict,\n", + " split_by_group=True,\n", + " seed=seed,\n", + " device=dev,\n", + " )\n", " scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)\n", " else:\n", - " train_keys, val_keys = df[df[\"datasplit\"] == \"training\"][\"group_id\"].unique(), df[df[\"datasplit\"] == \"validation\"][\"group_id\"].unique()\n", + " train_keys, val_keys = (\n", + " df[df[\"datasplit\"] == \"training\"][\"group_id\"].unique(),\n", + " df[df[\"datasplit\"] == \"validation\"][\"group_id\"].unique(),\n", + " )\n", " if down_sample:\n", " train_fraction = 0.10\n", - " train_keys, val_keys = subsample_keys(train_keys, val_keys, train_fraction) # Downsample training data for test\n", - " print(f\"Sample down to {train_fraction * 100}% with {len(train_keys)} training and {len(val_keys)} validation compounds \")\n", - " trainer = SpectralTrainer(geo_data, y_tag=y_label, problem_type=\"regression\", train_keys=train_keys, val_keys=val_keys, metric_dict=metric_dict, split_by_group=True, seed=seed, device=dev)\n", - " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 8, factor=0.5, mode = 'min')\n", - "\n", - " \n", - " checkpoints = trainer.train(model, optimizer, loss_fn, scheduler=scheduler, batch_size=training_params['batch_size'], epochs=training_params[\"epochs\"], val_every_n_epochs=1, with_CCS=training_params[\"with_CCS\"], with_RT=training_params[\"with_RT\"], use_validation_mask=False, tag=tag) #, mask_name=\"compiled_validation_maskALL\") \n", + " train_keys, val_keys = subsample_keys(\n", + " train_keys, val_keys, train_fraction\n", + " ) # Downsample training data for test\n", + " print(\n", + " f\"Sample down to {train_fraction * 100}% with {len(train_keys)} training and {len(val_keys)} validation compounds \"\n", + " )\n", + " trainer = SpectralTrainer(\n", + " geo_data,\n", + " y_tag=y_label,\n", + " problem_type=\"regression\",\n", + " train_keys=train_keys,\n", + " val_keys=val_keys,\n", + " metric_dict=metric_dict,\n", + " split_by_group=True,\n", + " seed=seed,\n", + " device=dev,\n", + " )\n", + " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", + " optimizer, patience=8, factor=0.5, mode=\"min\"\n", + " )\n", + "\n", + " checkpoints = trainer.train(\n", + " model,\n", + " optimizer,\n", + " loss_fn,\n", + " scheduler=scheduler,\n", + " batch_size=training_params[\"batch_size\"],\n", + " epochs=training_params[\"epochs\"],\n", + " val_every_n_epochs=1,\n", + " with_CCS=training_params[\"with_CCS\"],\n", + " with_RT=training_params[\"with_RT\"],\n", + " use_validation_mask=False,\n", + " tag=tag,\n", + " ) # , mask_name=\"compiled_validation_maskALL\")\n", " print(checkpoints)\n", " return model, checkpoints, trainer\n", "\n", + "\n", "def simulate_all(model, DF):\n", " return fiora.simulate_all(DF, model)\n", "\n", - " \n", + 
"\n", "def test_model(model, DF, score=\"spectral_sqrt_cosine\", return_df=False):\n", " dft = simulate_all(model, DF)\n", - " \n", + "\n", " if return_df:\n", " return dft\n", " return dft[score].values" @@ -561,39 +656,48 @@ "metadata": {}, "outputs": [], "source": [ - "def grid_search(param_grid, base_model_params, base_training_params, store_models: bool = False, prefix=\"config\"):\n", + "def grid_search(\n", + " param_grid,\n", + " base_model_params,\n", + " base_training_params,\n", + " store_models: bool = False,\n", + " prefix=\"config\",\n", + "):\n", " results = []\n", "\n", " for i, param_override in enumerate(param_grid):\n", " print(f\"Running configuration {i + 1}/{len(param_grid)}...\")\n", - " \n", + "\n", " # Update base parameters with overrides\n", " model_params = base_model_params.copy()\n", " training_params = base_training_params.copy()\n", - " \n", + "\n", " model_params.update(param_override.get(\"model_params\", {}))\n", " training_params.update(param_override.get(\"training_params\", {}))\n", - " \n", + "\n", " # Train the model with the updated parameters\n", " try:\n", - " model, checkpoints, trainer = train_new_model(model_params=model_params, training_params=training_params, tag=f\"{prefix}_{i + 1}\")\n", - " results.append({\n", - " \"config\": param_override,\n", - " \"model\": model if store_models else None,\n", - " \"checkpoints\": checkpoints,\n", - " \"trainer\": trainer\n", - " })\n", + " model, checkpoints, trainer = train_new_model(\n", + " model_params=model_params,\n", + " training_params=training_params,\n", + " tag=f\"{prefix}_{i + 1}\",\n", + " )\n", + " results.append(\n", + " {\n", + " \"config\": param_override,\n", + " \"model\": model if store_models else None,\n", + " \"checkpoints\": checkpoints,\n", + " \"trainer\": trainer,\n", + " }\n", + " )\n", "\n", " if not store_models:\n", " del model\n", " torch.cuda.empty_cache()\n", - " \n", + "\n", " except Exception as e:\n", " print(f\"Error in configuration {i + 1}: {e}\")\n", - " results.append({\n", - " \"config\": param_override,\n", - " \"error\": str(e)\n", - " })\n", + " results.append({\"config\": param_override, \"error\": str(e)})\n", "\n", " return results" ] @@ -1393,7 +1497,8 @@ ], "source": [ "import torch.multiprocessing as mp\n", - "mp.set_start_method('spawn', force=True)\n", + "\n", + "mp.set_start_method(\"spawn\", force=True)\n", "GRID_SEARCH = True\n", "if GRID_SEARCH:\n", " param_grid = [\n", @@ -1402,7 +1507,9 @@ " {\"model_params\": {}, \"training_params\": {}},\n", " ]\n", "\n", - " grid_results = grid_search(param_grid, model_params, training_params, store_models=True)\n", + " grid_results = grid_search(\n", + " param_grid, model_params, training_params, store_models=True\n", + " )\n", "\n", " # Analyze results\n", " for result in grid_results:\n", @@ -1419,13 +1526,13 @@ "metadata": {}, "outputs": [], "source": [ - "if 'model' in locals():\n", + "if \"model\" in locals():\n", " del model\n", " torch.cuda.empty_cache()\n", "\n", "if not GRID_SEARCH:\n", " print(f\"Training model\")\n", - " model, checkpoints, trainer = train_new_model() # continue_with_model=model)" + " model, checkpoints, trainer = train_new_model() # continue_with_model=model)" ] }, { @@ -1448,7 +1555,7 @@ "\n", "best_result = None\n", "if not GRID_SEARCH:\n", - " print(checkpoints) \n", + " print(checkpoints)\n", " model_at_last_epoch = copy.deepcopy(model)\n", "\n", "else:\n", @@ -1456,7 +1563,9 @@ " model = best_result[\"model\"]\n", " checkpoints = best_result[\"checkpoints\"]\n", " trainer 
= best_result[\"trainer\"]\n", - " print(f\"Best model found with val_sqrt_error: {best_result['checkpoints']['sqrt_val_loss']}\")\n", + " print(\n", + " f\"Best model found with val_sqrt_error: {best_result['checkpoints']['sqrt_val_loss']}\"\n", + " )\n", " print(\"Parameter overrides for the best model:\")\n", " print(best_result[\"config\"])" ] @@ -1488,6 +1597,7 @@ "source": [ "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", + "\n", "# Convert numpy arrays to scalars if they are single-element arrays\n", "trainer.history[\"train_error\"] = [\n", " error.item() if isinstance(error, np.ndarray) and error.size == 1 else error\n", @@ -1503,17 +1613,23 @@ "]\n", "\n", "# Create a DataFrame from the tracker dictionary\n", - "tracker_df = pd.DataFrame({\n", - " \"epoch\": trainer.history[\"epoch\"],\n", - " \"train_rmse\": trainer.history[\"sqrt_train_error\"],\n", - " \"val_rmse\": trainer.history[\"sqrt_val_error\"],\n", - " \"lr\": trainer.history[\"lr\"]\n", - "})\n", + "tracker_df = pd.DataFrame(\n", + " {\n", + " \"epoch\": trainer.history[\"epoch\"],\n", + " \"train_rmse\": trainer.history[\"sqrt_train_error\"],\n", + " \"val_rmse\": trainer.history[\"sqrt_val_error\"],\n", + " \"lr\": trainer.history[\"lr\"],\n", + " }\n", + ")\n", "\n", "# Plot the training and validation loss\n", "plt.figure(figsize=(10, 5))\n", - "sns.lineplot(data=tracker_df, x=\"epoch\", y=\"train_rmse\", label=\"Training RMSE\", color=\"blue\")\n", - "sns.lineplot(data=tracker_df, x=\"epoch\", y=\"val_rmse\", label=\"Validation RMSE\", color=\"orange\")\n", + "sns.lineplot(\n", + " data=tracker_df, x=\"epoch\", y=\"train_rmse\", label=\"Training RMSE\", color=\"blue\"\n", + ")\n", + "sns.lineplot(\n", + " data=tracker_df, x=\"epoch\", y=\"val_rmse\", label=\"Validation RMSE\", color=\"orange\"\n", + ")\n", "\n", "# Highlight the epochs where the learning rate changes\n", "previous_lr = None\n", @@ -1522,20 +1638,37 @@ " if current_lr != previous_lr:\n", " epoch = row[\"epoch\"]\n", " val_loss_at_epoch = row[\"val_rmse\"]\n", - " plt.scatter(epoch, val_loss_at_epoch + 0.0001, color=\"black\", marker=\"v\", label=\"LR Change\" if previous_lr is None else \"\")\n", - " plt.text(epoch, val_loss_at_epoch + 0.0002, f\"LR: {current_lr:1.0e}\", color=\"black\", ha=\"center\", fontsize=8)\n", + " plt.scatter(\n", + " epoch,\n", + " val_loss_at_epoch + 0.0001,\n", + " color=\"black\",\n", + " marker=\"v\",\n", + " label=\"LR Change\" if previous_lr is None else \"\",\n", + " )\n", + " plt.text(\n", + " epoch,\n", + " val_loss_at_epoch + 0.0002,\n", + " f\"LR: {current_lr:1.0e}\",\n", + " color=\"black\",\n", + " ha=\"center\",\n", + " fontsize=8,\n", + " )\n", " previous_lr = current_lr\n", "\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"RMSE\")\n", - "#plt.ylim(0, tracker_df[\"val_rmse\"].max() + 0.004)\n", + "# plt.ylim(0, tracker_df[\"val_rmse\"].max() + 0.004)\n", "plt.title(\"Training and Validation Loss Over Epochs\")\n", "plt.legend()\n", "plt.show()\n", "min_train_error = min(trainer.history[\"sqrt_train_error\"])\n", "min_val_error = min(trainer.history[\"sqrt_val_error\"])\n", - "epoch_min_train_error = trainer.history[\"epoch\"][np.argmin(trainer.history[\"sqrt_train_error\"])]\n", - "epoch_min_val_error = trainer.history[\"epoch\"][np.argmin(trainer.history[\"sqrt_val_error\"])]\n", + "epoch_min_train_error = trainer.history[\"epoch\"][\n", + " np.argmin(trainer.history[\"sqrt_train_error\"])\n", + "]\n", + "epoch_min_val_error = trainer.history[\"epoch\"][\n", + " 
np.argmin(trainer.history[\"sqrt_val_error\"])\n",
+    "]\n",
     "print(f\"Minimum Training RMSE: {min_train_error:.5f} (Epoch {epoch_min_train_error})\")\n",
     "print(f\"Minimum Validation RMSE: {min_val_error:.5f} (Epoch {epoch_min_val_error})\")"
    ]
   },
@@ -1552,36 +1685,48 @@
     "df_cas[\"Metabolite\"] = df_cas[\"SMILES\"].apply(Metabolite)\n",
     "df_cas[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n",
     "\n",
-    "df_cas[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n",
-    "df_cas[\"CE\"] = 20.0 # actually stepped 20/35/50\n",
-    "df_cas[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n",
-    "\n",
-    "metadata_key_map16 = {\"collision_energy\": \"CE\", \n",
-    "                      \"instrument\": \"Instrument_type\",\n",
-    "                      \"precursor_mz\": \"PRECURSOR_MZ\",\n",
-    "                      'precursor_mode': \"Precursor_type\",\n",
-    "                      \"retention_time\": \"RETENTIONTIME\"\n",
-    "                      }\n",
+    "df_cas[\"Metabolite\"].apply(\n",
+    "    lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n",
+    ")\n",
+    "df_cas[\"CE\"] = 20.0  # actually stepped 20/35/50\n",
+    "df_cas[\"Instrument_type\"] = \"HCD\"  # CHECK if correct Orbitrap\n",
+    "\n",
+    "metadata_key_map16 = {\n",
+    "    \"collision_energy\": \"CE\",\n",
+    "    \"instrument\": \"Instrument_type\",\n",
+    "    \"precursor_mz\": \"PRECURSOR_MZ\",\n",
+    "    \"precursor_mode\": \"Precursor_type\",\n",
+    "    \"retention_time\": \"RETENTIONTIME\",\n",
+    "}\n",
     "\n",
-    "df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n",
-    "df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder), axis=1)\n",
+    "df_cas[\"summary\"] = df_cas.apply(\n",
+    "    lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n",
+    ")\n",
+    "df_cas.apply(\n",
+    "    lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder), axis=1\n",
+    ")\n",
     "\n",
     "# Fragmentation\n",
     "df_cas[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n",
-    "_ = df_cas.apply(lambda x: x[\"Metabolite\"].match_fragments_to_peaks(x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM), axis=1) # Optional: use mz_cut instead\n",
+    "_ = df_cas.apply(\n",
+    "    lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n",
+    "        x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n",
+    "    ),\n",
+    "    axis=1,\n",
+    ")  # Optional: use mz_cut instead\n",
     "\n",
     "#\n",
     "# CASMI 22\n",
     "#\n",
     "\n",
-    "# TODO Currently errorneous and commented out until further notice \n",
+    "# TODO Currently erroneous and commented out until further notice\n",
     "# df_cas22[\"Metabolite\"] = df_cas22[\"SMILES\"].apply(Metabolite)\n",
     "# df_cas22[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n",
     "\n",
     "# df_cas22[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n",
     "# df_cas22[\"CE\"] = df_cas22.apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"precursor_mz\"]), axis=1)\n",
     "\n",
-    "# metadata_key_map22 = {\"collision_energy\": \"CE\", \n",
+    "# metadata_key_map22 = {\"collision_energy\": \"CE\",\n",
     "#                      \"instrument\": \"Instrument_type\",\n",
     "#                      \"precursor_mz\": \"precursor_mz\",\n",
     "#                      'precursor_mode': \"Precursor_type\",\n",
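The next hunk reformats the CASMI-16 evaluation, which simulates every spectrum at the three stepped collision energies (NCE 20, 35, 50), converts each normalized value to an absolute energy with `NCE_to_eV`, and merges the three predictions before scoring. As a rough sketch of that conversion, assuming the common single-charge approximation eV ≈ NCE * (precursor m/z / 500); the repository's `NCE_to_eV` may differ in detail:

```python
# Hedged sketch only: this linear single-charge approximation is an assumption,
# not necessarily the formula implemented by fiora's NCE_to_eV.
def nce_to_ev_approx(nce: float, precursor_mz: float) -> float:
    return nce * precursor_mz / 500.0


# A stepped NCE 20/35/50 acquisition of an m/z 300 precursor maps to roughly
# 12, 21, and 30 eV, which is what the three simulate_all passes emulate.
print([nce_to_ev_approx(nce, 300.0) for nce in (20.0, 35.0, 50.0)])
```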
@@ -1605,43 +1750,108 @@
    "outputs": [],
    "source": [
     "from fiora.MOL.collision_energy import NCE_to_eV\n",
-    "from fiora.MS.spectral_scores import spectral_cosine, spectral_reflection_cosine, reweighted_dot\n",
+    "from fiora.MS.spectral_scores import (\n",
+    "    spectral_cosine,\n",
+    "    spectral_reflection_cosine,\n",
+    "    reweighted_dot,\n",
+    ")\n",
     "from fiora.MS.ms_utility import merge_annotated_spectrum\n",
     "\n",
+    "\n",
     "def test_cas16(model, df_cas=df_cas, score=\"merged_sqrt_cosine\", return_df=False):\n",
-    "    \n",
-    "    df_cas[\"NCE\"] = 20.0 # actually stepped NCE 20/35/50\n",
-    "    df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1)\n",
+    "\n",
+    "    df_cas[\"NCE\"] = 20.0  # actually stepped NCE 20/35/50\n",
+    "    df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n",
+    "        lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n",
+    "    )\n",
     "    df_cas[\"step1_CE\"] = df_cas[\"CE\"]\n",
-    "    df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n",
-    "    df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder), axis=1)\n",
+    "    df_cas[\"summary\"] = df_cas.apply(\n",
+    "        lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n",
+    "    )\n",
+    "    df_cas.apply(\n",
+    "        lambda x: x[\"Metabolite\"].add_metadata(\n",
+    "            x[\"summary\"], covariate_encoder, rt_encoder\n",
+    "        ),\n",
+    "        axis=1,\n",
+    "    )\n",
     "    df_cas = fiora.simulate_all(df_cas, model, suffix=\"_20\")\n",
     "\n",
-    "    df_cas[\"NCE\"] = 35.0 # actually stepped NCE 20/35/50\n",
-    "    df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1)\n",
+    "    df_cas[\"NCE\"] = 35.0  # actually stepped NCE 20/35/50\n",
+    "    df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n",
+    "        lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n",
+    "    )\n",
     "    df_cas[\"step2_CE\"] = df_cas[\"CE\"]\n",
-    "    df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n",
-    "    df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder), axis=1)\n",
+    "    df_cas[\"summary\"] = df_cas.apply(\n",
+    "        lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n",
+    "    )\n",
+    "    df_cas.apply(\n",
+    "        lambda x: x[\"Metabolite\"].add_metadata(\n",
+    "            x[\"summary\"], covariate_encoder, rt_encoder\n",
+    "        ),\n",
+    "        axis=1,\n",
+    "    )\n",
     "    df_cas = fiora.simulate_all(df_cas, model, suffix=\"_35\")\n",
     "\n",
-    "\n",
-    "    df_cas[\"NCE\"] = 50.0 # actually stepped NCE 20/35/50\n",
-    "    df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1)\n",
+    "    df_cas[\"NCE\"] = 50.0  # actually stepped NCE 20/35/50\n",
+    "    df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n",
+    "        lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n",
+    "    )\n",
     "    df_cas[\"step3_CE\"] = df_cas[\"CE\"]\n",
-    "    df_cas[\"summary\"] = df_cas.apply(lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1)\n",
-    "    df_cas.apply(lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder), axis=1)\n",
+    "    df_cas[\"summary\"] = df_cas.apply(\n",
+    "        lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n",
+    "    )\n",
+    "    df_cas.apply(\n",
+    "        lambda x: x[\"Metabolite\"].add_metadata(\n",
+    "            x[\"summary\"], covariate_encoder, rt_encoder\n",
+    "        ),\n",
+    "        axis=1,\n",
+    "    )\n",
     "    df_cas = fiora.simulate_all(df_cas, model, suffix=\"_50\")\n",
     "\n",
-    "    df_cas[\"avg_CE\"] = (df_cas[\"step1_CE\"] + df_cas[\"step2_CE\"] + df_cas[\"step3_CE\"]) / 3\n",
-    "\n",
-    "    
df_cas[\"merged_peaks\"] = df_cas.apply(lambda x: merge_annotated_spectrum(merge_annotated_spectrum(x[\"sim_peaks_20\"], x[\"sim_peaks_35\"]), x[\"sim_peaks_50\"]) , axis=1)\n", - " df_cas[\"merged_cosine\"] = df_cas.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"]), axis=1)\n", - " df_cas[\"merged_sqrt_cosine\"] = df_cas.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt), axis=1)\n", - " df_cas[\"merged_sqrt_cosine_wo_prec\"] = df_cas.apply(lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt, remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(x[\"Metabolite\"].metadata[\"precursor_mode\"])), axis=1)\n", - " df_cas[\"merged_refl_cosine\"] = df_cas.apply(lambda x: spectral_reflection_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt), axis=1)\n", - " df_cas[\"merged_steins\"] = df_cas.apply(lambda x: reweighted_dot(x[\"peaks\"], x[\"merged_peaks\"]), axis=1)\n", - " df_cas[\"spectral_sqrt_cosine\"] = df_cas[\"merged_sqrt_cosine\"] # just remember it is merged\n", - " df_cas[\"spectral_sqrt_cosine_wo_prec\"] = df_cas[\"merged_sqrt_cosine_wo_prec\"] # just remember it is merged\n", + " df_cas[\"avg_CE\"] = (\n", + " df_cas[\"step1_CE\"] + df_cas[\"step2_CE\"] + df_cas[\"step3_CE\"]\n", + " ) / 3\n", + "\n", + " df_cas[\"merged_peaks\"] = df_cas.apply(\n", + " lambda x: merge_annotated_spectrum(\n", + " merge_annotated_spectrum(x[\"sim_peaks_20\"], x[\"sim_peaks_35\"]),\n", + " x[\"sim_peaks_50\"],\n", + " ),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_cosine\"] = df_cas.apply(\n", + " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " )\n", + " df_cas[\"merged_sqrt_cosine\"] = df_cas.apply(\n", + " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_sqrt_cosine_wo_prec\"] = df_cas.apply(\n", + " lambda x: spectral_cosine(\n", + " x[\"peaks\"],\n", + " x[\"merged_peaks\"],\n", + " transform=np.sqrt,\n", + " remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(\n", + " x[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " ),\n", + " ),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_refl_cosine\"] = df_cas.apply(\n", + " lambda x: spectral_reflection_cosine(\n", + " x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt\n", + " ),\n", + " axis=1,\n", + " )\n", + " df_cas[\"merged_steins\"] = df_cas.apply(\n", + " lambda x: reweighted_dot(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " )\n", + " df_cas[\"spectral_sqrt_cosine\"] = df_cas[\n", + " \"merged_sqrt_cosine\"\n", + " ] # just remember it is merged\n", + " df_cas[\"spectral_sqrt_cosine_wo_prec\"] = df_cas[\n", + " \"merged_sqrt_cosine_wo_prec\"\n", + " ] # just remember it is merged\n", "\n", " df_cas[\"coverage\"] = df_cas[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", " if hasattr(model, \"rt_module\"):\n", @@ -1650,10 +1860,10 @@ " if hasattr(model, \"ccs_module\"):\n", " df_cas[\"CCS_pred\"] = df_cas[\"CCS_pred_35\"]\n", " df_cas[\"library\"] = \"CASMI-16\"\n", - " \n", + "\n", " if return_df:\n", " return df_cas\n", - " \n", + "\n", " return df_cas[score].values" ] }, @@ -1663,7 +1873,11 @@ "metadata": {}, "outputs": [], "source": [ - "model = FioraModel.load(checkpoints[\"file\"]).to(dev) if not GRID_SEARCH else FioraModel.load(best_result[\"checkpoints\"][\"file\"]).to(dev)\n", + "model = (\n", + " FioraModel.load(checkpoints[\"file\"]).to(dev)\n", + " if not GRID_SEARCH\n", + " else 
FioraModel.load(best_result[\"checkpoints\"][\"file\"]).to(dev)\n", + ")\n", "df_val = df_train[df_train[\"datasplit\"] == \"validation\"]\n", "\n", "df_val = test_model(model, df_val, return_df=True)\n", @@ -1678,7 +1892,9 @@ "outputs": [], "source": [ "from fiora.MOL.constants import DEFAULT_DALTON\n", - "from fiora.MS.spectral_scores import spectral_cosine \n", + "from fiora.MS.spectral_scores import spectral_cosine\n", + "\n", + "\n", "def construct_explained_peaks(df, tolerance):\n", " explained_peaks_list = []\n", "\n", @@ -1697,23 +1913,58 @@ " explained_intensity.append(intensity)\n", " break # Stop checking once a match is found\n", "\n", - " explained_peaks_list.append({\"mz\": explained_mz, \"intensity\": explained_intensity})\n", + " explained_peaks_list.append(\n", + " {\"mz\": explained_mz, \"intensity\": explained_intensity}\n", + " )\n", "\n", " df[\"explained_peaks\"] = explained_peaks_list\n", " return df\n", "\n", "\n", "df_test = construct_explained_peaks(df_test, DEFAULT_DALTON)\n", - "df_test[\"explained_sqrt_cosine\"] = df_test.apply(lambda x: spectral_cosine(x[\"explained_peaks\"], x[\"peaks\"], transform=np.sqrt), axis=1)\n", - "df_test[\"explained_sqrt_cosine_wo_prec\"] = df_test.apply(lambda x: spectral_cosine(x[\"explained_peaks\"], x[\"peaks\"], transform=np.sqrt, remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz()), axis=1)\n", + "df_test[\"explained_sqrt_cosine\"] = df_test.apply(\n", + " lambda x: spectral_cosine(x[\"explained_peaks\"], x[\"peaks\"], transform=np.sqrt),\n", + " axis=1,\n", + ")\n", + "df_test[\"explained_sqrt_cosine_wo_prec\"] = df_test.apply(\n", + " lambda x: spectral_cosine(\n", + " x[\"explained_peaks\"],\n", + " x[\"peaks\"],\n", + " transform=np.sqrt,\n", + " remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(),\n", + " ),\n", + " axis=1,\n", + ")\n", "\n", "df_val = construct_explained_peaks(df_val, DEFAULT_DALTON)\n", - "df_val[\"explained_sqrt_cosine\"] = df_val.apply(lambda x: spectral_cosine(x[\"explained_peaks\"], x[\"peaks\"], transform=np.sqrt), axis=1)\n", - "df_val[\"explained_sqrt_cosine_wo_prec\"] = df_val.apply(lambda x: spectral_cosine(x[\"explained_peaks\"], x[\"peaks\"], transform=np.sqrt, remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz()), axis=1)\n", + "df_val[\"explained_sqrt_cosine\"] = df_val.apply(\n", + " lambda x: spectral_cosine(x[\"explained_peaks\"], x[\"peaks\"], transform=np.sqrt),\n", + " axis=1,\n", + ")\n", + "df_val[\"explained_sqrt_cosine_wo_prec\"] = df_val.apply(\n", + " lambda x: spectral_cosine(\n", + " x[\"explained_peaks\"],\n", + " x[\"peaks\"],\n", + " transform=np.sqrt,\n", + " remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(),\n", + " ),\n", + " axis=1,\n", + ")\n", "\n", "df_cas16 = construct_explained_peaks(df_cas16, DEFAULT_DALTON)\n", - "df_cas16[\"explained_sqrt_cosine\"] = df_cas16.apply(lambda x: spectral_cosine(x[\"explained_peaks\"], x[\"peaks\"], transform=np.sqrt), axis=1)\n", - "df_cas16[\"explained_sqrt_cosine_wo_prec\"] = df_cas16.apply(lambda x: spectral_cosine(x[\"explained_peaks\"], x[\"peaks\"], transform=np.sqrt, remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz()), axis=1)\n" + "df_cas16[\"explained_sqrt_cosine\"] = df_cas16.apply(\n", + " lambda x: spectral_cosine(x[\"explained_peaks\"], x[\"peaks\"], transform=np.sqrt),\n", + " axis=1,\n", + ")\n", + "df_cas16[\"explained_sqrt_cosine_wo_prec\"] = df_cas16.apply(\n", + " lambda x: spectral_cosine(\n", + " x[\"explained_peaks\"],\n", + " x[\"peaks\"],\n", + " transform=np.sqrt,\n", + " 
remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(),\n", + " ),\n", + " axis=1,\n", + ")" ] }, { @@ -1799,30 +2050,31 @@ "avg_func = np.median\n", "\n", "# Dataframe with Val, Test, CASMI-16 rows and fiora model median scores w/o precursor\n", - "df_scores = pd.DataFrame({\n", - " \"dataset\": [\"Validation\", \"Test\", \"CASMI-16\"],\n", - " \n", - " \"spectral_sqrt_cosine\": [\n", - " avg_func(df_val[\"spectral_sqrt_cosine\"]),\n", - " avg_func(df_test[\"spectral_sqrt_cosine\"]),\n", - " avg_func(df_cas16[\"spectral_sqrt_cosine\"].fillna(0))\n", - " ],\n", - " \"explained_sqrt_cosine\": [\n", - " avg_func(df_val[\"explained_sqrt_cosine\"]),\n", - " avg_func(df_test[\"explained_sqrt_cosine\"]),\n", - " avg_func(df_cas16[\"explained_sqrt_cosine\"].fillna(0))\n", - " ],\n", - " \"spectral_sqrt_cosine_wo_prec\": [\n", - " avg_func(df_val[\"spectral_sqrt_cosine_wo_prec\"]),\n", - " avg_func(df_test[\"spectral_sqrt_cosine_wo_prec\"]),\n", - " avg_func(df_cas16[\"spectral_sqrt_cosine_wo_prec\"].fillna(0))\n", - " ],\n", - " \"explained_sqrt_cosine_wo_prec\": [\n", - " avg_func(df_val[\"explained_sqrt_cosine_wo_prec\"]),\n", - " avg_func(df_test[\"explained_sqrt_cosine_wo_prec\"]),\n", - " avg_func(df_cas16[\"explained_sqrt_cosine_wo_prec\"].fillna(0))\n", - " ]\n", - "})\n", + "df_scores = pd.DataFrame(\n", + " {\n", + " \"dataset\": [\"Validation\", \"Test\", \"CASMI-16\"],\n", + " \"spectral_sqrt_cosine\": [\n", + " avg_func(df_val[\"spectral_sqrt_cosine\"]),\n", + " avg_func(df_test[\"spectral_sqrt_cosine\"]),\n", + " avg_func(df_cas16[\"spectral_sqrt_cosine\"].fillna(0)),\n", + " ],\n", + " \"explained_sqrt_cosine\": [\n", + " avg_func(df_val[\"explained_sqrt_cosine\"]),\n", + " avg_func(df_test[\"explained_sqrt_cosine\"]),\n", + " avg_func(df_cas16[\"explained_sqrt_cosine\"].fillna(0)),\n", + " ],\n", + " \"spectral_sqrt_cosine_wo_prec\": [\n", + " avg_func(df_val[\"spectral_sqrt_cosine_wo_prec\"]),\n", + " avg_func(df_test[\"spectral_sqrt_cosine_wo_prec\"]),\n", + " avg_func(df_cas16[\"spectral_sqrt_cosine_wo_prec\"].fillna(0)),\n", + " ],\n", + " \"explained_sqrt_cosine_wo_prec\": [\n", + " avg_func(df_val[\"explained_sqrt_cosine_wo_prec\"]),\n", + " avg_func(df_test[\"explained_sqrt_cosine_wo_prec\"]),\n", + " avg_func(df_cas16[\"explained_sqrt_cosine_wo_prec\"].fillna(0)),\n", + " ],\n", + " }\n", + ")\n", "\n", "df_scores" ] @@ -1847,14 +2099,15 @@ "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "from fiora.visualization.define_colors import set_light_theme\n", + "\n", "set_light_theme()\n", "\n", "# Define custom colors for the precursor types\n", "custom_palette = {\n", " \"[M+H]+\": \"indianred\", # Nice red tint\n", - " \"[M]+\": \"salmon\", # Nice orange tint\n", + " \"[M]+\": \"salmon\", # Nice orange tint\n", " \"[M-H]-\": \"dodgerblue\", # Nice blue tint\n", - " \"[M]-\": \"lightblue\" # Different blue tint\n", + " \"[M]-\": \"lightblue\", # Different blue tint\n", "}\n", "precursor_types = [\"[M+H]+\", \"[M]+\", \"[M-H]-\", \"[M]-\"]\n", "\n", @@ -1862,30 +2115,70 @@ "fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharey=True)\n", "\n", "# Left subplot: spectral_sqrt_cosine vs spectral_sqrt_cosine with Precursor_type as hue\n", - "sns.boxplot(data=df_val, x=\"Precursor_type\", y=\"spectral_sqrt_cosine\", ax=axes[0], palette=custom_palette, hue=\"Precursor_type\", linewidth=2, legend=False, order=precursor_types)\n", + "sns.boxplot(\n", + " data=df_val,\n", + " x=\"Precursor_type\",\n", + " y=\"spectral_sqrt_cosine\",\n", + " ax=axes[0],\n", + " 
palette=custom_palette,\n", + " hue=\"Precursor_type\",\n", + " linewidth=2,\n", + " legend=False,\n", + " order=precursor_types,\n", + ")\n", "axes[0].set_title(\"Spectral Cosine\")\n", "axes[0].set_ylabel(\"Cosine Similarity\")\n", "axes[0].set_xlabel(\"Precursor Type\")\n", "\n", "# Add medians above the median line for the left subplot\n", "for i, precursor_type in enumerate(precursor_types):\n", - " median = df_val[df_val[\"Precursor_type\"] == precursor_type][\"spectral_sqrt_cosine\"].median()\n", - " axes[0].text(i, median + 0.01, f\"{median:.2f}\", ha=\"center\", va=\"bottom\", fontsize=11, color=\"black\")\n", + " median = df_val[df_val[\"Precursor_type\"] == precursor_type][\n", + " \"spectral_sqrt_cosine\"\n", + " ].median()\n", + " axes[0].text(\n", + " i,\n", + " median + 0.01,\n", + " f\"{median:.2f}\",\n", + " ha=\"center\",\n", + " va=\"bottom\",\n", + " fontsize=11,\n", + " color=\"black\",\n", + " )\n", "\n", "# Right subplot: spectral_sqrt_cosine_wo_prec vs spectral_sqrt_cosine_wo_prec with Precursor_type as hue\n", - "sns.boxplot(data=df_val, x=\"Precursor_type\", y=\"spectral_sqrt_cosine_wo_prec\", ax=axes[1], hue=\"Precursor_type\", palette=custom_palette, linewidth=2, legend=True, order=precursor_types)\n", + "sns.boxplot(\n", + " data=df_val,\n", + " x=\"Precursor_type\",\n", + " y=\"spectral_sqrt_cosine_wo_prec\",\n", + " ax=axes[1],\n", + " hue=\"Precursor_type\",\n", + " palette=custom_palette,\n", + " linewidth=2,\n", + " legend=True,\n", + " order=precursor_types,\n", + ")\n", "axes[1].set_title(\"Spectral Cosine (wo Prec)\")\n", "axes[1].set_ylabel(\"Cosine Similarity\")\n", "axes[1].set_xlabel(\"Precursor Type\")\n", "\n", "# Add medians above the median line for the right subplot\n", "for i, precursor_type in enumerate(precursor_types):\n", - " median = df_val[df_val[\"Precursor_type\"] == precursor_type][\"spectral_sqrt_cosine_wo_prec\"].median()\n", - " axes[1].text(i, median + 0.01, f\"{median:.2f}\", ha=\"center\", va=\"bottom\", fontsize=11, color=\"black\")\n", + " median = df_val[df_val[\"Precursor_type\"] == precursor_type][\n", + " \"spectral_sqrt_cosine_wo_prec\"\n", + " ].median()\n", + " axes[1].text(\n", + " i,\n", + " median + 0.01,\n", + " f\"{median:.2f}\",\n", + " ha=\"center\",\n", + " va=\"bottom\",\n", + " fontsize=11,\n", + " color=\"black\",\n", + " )\n", "\n", "# Adjust layout and show the plot\n", "plt.tight_layout()\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -1905,34 +2198,73 @@ } ], "source": [ - "\n", "fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharey=True)\n", "\n", "# Left subplot: spectral_sqrt_cosine vs spectral_sqrt_cosine with Precursor_type as hue\n", - "sns.boxplot(data=df_test, x=\"Precursor_type\", y=\"spectral_sqrt_cosine\", ax=axes[0], palette=custom_palette, hue=\"Precursor_type\", linewidth=2, legend=False, order=precursor_types)\n", + "sns.boxplot(\n", + " data=df_test,\n", + " x=\"Precursor_type\",\n", + " y=\"spectral_sqrt_cosine\",\n", + " ax=axes[0],\n", + " palette=custom_palette,\n", + " hue=\"Precursor_type\",\n", + " linewidth=2,\n", + " legend=False,\n", + " order=precursor_types,\n", + ")\n", "axes[0].set_title(\"Spectral Cosine\")\n", "axes[0].set_ylabel(\"Cosine Similarity\")\n", "axes[0].set_xlabel(\"Precursor Type\")\n", "\n", "# Add medians above the median line for the left subplot\n", "for i, precursor_type in enumerate(precursor_types):\n", - " median = df_test[df_test[\"Precursor_type\"] == precursor_type][\"spectral_sqrt_cosine\"].median()\n", - " axes[0].text(i, median + 0.01, 
f\"{median:.2f}\", ha=\"center\", va=\"bottom\", fontsize=11, color=\"black\")\n", + " median = df_test[df_test[\"Precursor_type\"] == precursor_type][\n", + " \"spectral_sqrt_cosine\"\n", + " ].median()\n", + " axes[0].text(\n", + " i,\n", + " median + 0.01,\n", + " f\"{median:.2f}\",\n", + " ha=\"center\",\n", + " va=\"bottom\",\n", + " fontsize=11,\n", + " color=\"black\",\n", + " )\n", "\n", "# Right subplot: spectral_sqrt_cosine_wo_prec vs spectral_sqrt_cosine_wo_prec with Precursor_type as hue\n", - "sns.boxplot(data=df_test, x=\"Precursor_type\", y=\"spectral_sqrt_cosine_wo_prec\", ax=axes[1], hue=\"Precursor_type\", palette=custom_palette, linewidth=2, legend=True, order=precursor_types)\n", + "sns.boxplot(\n", + " data=df_test,\n", + " x=\"Precursor_type\",\n", + " y=\"spectral_sqrt_cosine_wo_prec\",\n", + " ax=axes[1],\n", + " hue=\"Precursor_type\",\n", + " palette=custom_palette,\n", + " linewidth=2,\n", + " legend=True,\n", + " order=precursor_types,\n", + ")\n", "axes[1].set_title(\"Spectral Cosine (wo Prec)\")\n", "axes[1].set_ylabel(\"Cosine Similarity\")\n", "axes[1].set_xlabel(\"Precursor Type\")\n", "\n", "# Add medians above the median line for the right subplot\n", "for i, precursor_type in enumerate(precursor_types):\n", - " median = df_test[df_test[\"Precursor_type\"] == precursor_type][\"spectral_sqrt_cosine_wo_prec\"].median()\n", - " axes[1].text(i, median + 0.01, f\"{median:.2f}\", ha=\"center\", va=\"bottom\", fontsize=11, color=\"black\")\n", + " median = df_test[df_test[\"Precursor_type\"] == precursor_type][\n", + " \"spectral_sqrt_cosine_wo_prec\"\n", + " ].median()\n", + " axes[1].text(\n", + " i,\n", + " median + 0.01,\n", + " f\"{median:.2f}\",\n", + " ha=\"center\",\n", + " va=\"bottom\",\n", + " fontsize=11,\n", + " color=\"black\",\n", + " )\n", "\n", "# Adjust layout and show the plot\n", "plt.tight_layout()\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -1952,34 +2284,71 @@ } ], "source": [ - "\n", "fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharey=True)\n", "\n", "# Left subplot: spectral_sqrt_cosine vs spectral_sqrt_cosine with Precursor_type as hue\n", - "sns.boxplot(data=df_cas16, x=\"Precursor_type\", y=\"spectral_sqrt_cosine\", ax=axes[0], palette=custom_palette, hue=\"Precursor_type\", linewidth=2, legend=False)\n", + "sns.boxplot(\n", + " data=df_cas16,\n", + " x=\"Precursor_type\",\n", + " y=\"spectral_sqrt_cosine\",\n", + " ax=axes[0],\n", + " palette=custom_palette,\n", + " hue=\"Precursor_type\",\n", + " linewidth=2,\n", + " legend=False,\n", + ")\n", "axes[0].set_title(\"Spectral Cosine\")\n", "axes[0].set_ylabel(\"Cosine Similarity\")\n", "axes[0].set_xlabel(\"Precursor Type\")\n", "\n", "# Add medians above the median line for the left subplot\n", "for i, precursor_type in enumerate(df_cas16[\"Precursor_type\"].unique()):\n", - " median = df_cas16[df_cas16[\"Precursor_type\"] == precursor_type][\"spectral_sqrt_cosine\"].median()\n", - " axes[0].text(i, median + 0.01, f\"{median:.2f}\", ha=\"center\", va=\"bottom\", fontsize=11, color=\"black\")\n", + " median = df_cas16[df_cas16[\"Precursor_type\"] == precursor_type][\n", + " \"spectral_sqrt_cosine\"\n", + " ].median()\n", + " axes[0].text(\n", + " i,\n", + " median + 0.01,\n", + " f\"{median:.2f}\",\n", + " ha=\"center\",\n", + " va=\"bottom\",\n", + " fontsize=11,\n", + " color=\"black\",\n", + " )\n", "\n", "# Right subplot: spectral_sqrt_cosine_wo_prec vs spectral_sqrt_cosine_wo_prec with Precursor_type as hue\n", - "sns.boxplot(data=df_cas16, x=\"Precursor_type\", 
y=\"spectral_sqrt_cosine_wo_prec\", ax=axes[1], hue=\"Precursor_type\", palette=custom_palette, linewidth=2, legend=True)\n", + "sns.boxplot(\n", + " data=df_cas16,\n", + " x=\"Precursor_type\",\n", + " y=\"spectral_sqrt_cosine_wo_prec\",\n", + " ax=axes[1],\n", + " hue=\"Precursor_type\",\n", + " palette=custom_palette,\n", + " linewidth=2,\n", + " legend=True,\n", + ")\n", "axes[1].set_title(\"Spectral Cosine (wo Prec)\")\n", "axes[1].set_ylabel(\"Cosine Similarity\")\n", "axes[1].set_xlabel(\"Precursor Type\")\n", "\n", "# Add medians above the median line for the right subplot\n", "for i, precursor_type in enumerate(df_cas16[\"Precursor_type\"].unique()):\n", - " median = df_cas16[df_cas16[\"Precursor_type\"] == precursor_type][\"spectral_sqrt_cosine_wo_prec\"].median()\n", - " axes[1].text(i, median + 0.01, f\"{median:.2f}\", ha=\"center\", va=\"bottom\", fontsize=11, color=\"black\")\n", + " median = df_cas16[df_cas16[\"Precursor_type\"] == precursor_type][\n", + " \"spectral_sqrt_cosine_wo_prec\"\n", + " ].median()\n", + " axes[1].text(\n", + " i,\n", + " median + 0.01,\n", + " f\"{median:.2f}\",\n", + " ha=\"center\",\n", + " va=\"bottom\",\n", + " fontsize=11,\n", + " color=\"black\",\n", + " )\n", "\n", "# Adjust layout and show the plot\n", "plt.tight_layout()\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -2010,20 +2379,38 @@ "bin_counts = df_val[\"mol_weight_bin\"].value_counts(sort=False)\n", "\n", "# Create the figure and subplots\n", - "fig, axes = plt.subplots(2, 1, figsize=(12, 8), gridspec_kw={\"height_ratios\": [1, 3]}, sharex=True)\n", + "fig, axes = plt.subplots(\n", + " 2, 1, figsize=(12, 8), gridspec_kw={\"height_ratios\": [1, 3]}, sharex=True\n", + ")\n", "\n", "# Generate the viridis color palette\n", "viridis_palette = sns.color_palette(\"viridis\", len(bin_counts))\n", "\n", "# Top subplot: Bar plot for counts in each bin\n", - "sns.barplot(x=bin_counts.index.astype(str), y=bin_counts.values, ax=axes[0], palette=viridis_palette, hue=bin_counts.index.astype(str), legend=False, dodge=False)\n", + "sns.barplot(\n", + " x=bin_counts.index.astype(str),\n", + " y=bin_counts.values,\n", + " ax=axes[0],\n", + " palette=viridis_palette,\n", + " hue=bin_counts.index.astype(str),\n", + " legend=False,\n", + " dodge=False,\n", + ")\n", "axes[0].set_title(\"Counts per Molecular Weight Bin\")\n", "axes[0].set_ylabel(\"Count\")\n", "axes[0].set_xlabel(\"\")\n", "axes[0].tick_params(axis=\"x\", rotation=45)\n", "\n", "# Bottom subplot: Boxplot for spectral sqrt cosine values for each bin\n", - "sns.boxplot(data=df_val, x=\"mol_weight_bin\", y=\"spectral_sqrt_cosine\", ax=axes[1], palette=\"viridis\", hue=\"mol_weight_bin\", legend=False)\n", + "sns.boxplot(\n", + " data=df_val,\n", + " x=\"mol_weight_bin\",\n", + " y=\"spectral_sqrt_cosine\",\n", + " ax=axes[1],\n", + " palette=\"viridis\",\n", + " hue=\"mol_weight_bin\",\n", + " legend=False,\n", + ")\n", "axes[1].set_title(\"Spectral Sqrt Cosine vs Molecular Weight\")\n", "axes[1].set_ylabel(\"Spectral Sqrt Cosine\")\n", "axes[1].set_xlabel(\"Molecular Weight Bin\")\n", @@ -2031,7 +2418,7 @@ "\n", "# Adjust layout and show the plot\n", "plt.tight_layout()\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -2100,8 +2487,12 @@ "model.model_params[\"version_number\"] = \"1.0.0\"\n", "\n", "model.model_params[\"training_library\"] = \"MSnLib v7\"\n", - "model.model_params[\"comment\"] = \"This is an open-source FIORA model released on GitHub trained on the MSnLib v7.\"\n", - "model.model_params[\"disclaimer\"] = \"No prediction 
software is perfect. Use with caution.\"" + "model.model_params[\"comment\"] = (\n", + " \"This is an open-source FIORA model released on GitHub trained on the MSnLib v7.\"\n", + ")\n", + "model.model_params[\"disclaimer\"] = (\n", + " \"No prediction software is perfect. Use with caution.\"\n", + ")" ] }, { @@ -2123,7 +2514,9 @@ "if save_model:\n", " depth = model.model_params[\"depth\"]\n", " print(f\"Saving model with depth {depth}\")\n", - " MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v1.0.0_OS_depth{depth}_Sep25_x4.pt\"\n", + " MODEL_PATH = (\n", + " f\"{home}/data/metabolites/pretrained_models/v1.0.0_OS_depth{depth}_Sep25_x4.pt\"\n", + " )\n", " model.save(MODEL_PATH)\n", " print(f\"Saved to {MODEL_PATH}\")" ] @@ -2157,9 +2550,13 @@ " continue\n", " model_i = FioraModel.load(ckpt_path).to(dev)\n", " model_i.eval()\n", - " dfi = test_model(model_i, df_val.copy(deep=True), return_df=True) # adds spectral_sqrt_cosine\n", + " dfi = test_model(\n", + " model_i, df_val.copy(deep=True), return_df=True\n", + " ) # adds spectral_sqrt_cosine\n", " depth = i - 1 # file index -> model depth\n", - " tmp = dfi[[\"group_id\", \"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\"]].copy()\n", + " tmp = dfi[\n", + " [\"group_id\", \"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\"]\n", + " ].copy()\n", " tmp[\"depth\"] = depth\n", " pred_rows.append(tmp)\n", " del model_i\n", @@ -2190,21 +2587,31 @@ "import matplotlib.pyplot as plt\n", "\n", "# Build long-form data from raw preds_long\n", - "plot_long = (\n", - " preds_long.assign(depth=preds_long[\"depth\"].astype(int))\n", - " .melt(id_vars=[\"group_id\", \"depth\"],\n", - " value_vars=[\"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\"],\n", - " var_name=\"metric\", value_name=\"score\")\n", + "plot_long = preds_long.assign(depth=preds_long[\"depth\"].astype(int)).melt(\n", + " id_vars=[\"group_id\", \"depth\"],\n", + " value_vars=[\"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\"],\n", + " var_name=\"metric\",\n", + " value_name=\"score\",\n", + ")\n", + "plot_long[\"metric\"] = plot_long[\"metric\"].map(\n", + " {\n", + " \"spectral_sqrt_cosine\": \"Cosine Similarity\",\n", + " \"spectral_sqrt_cosine_wo_prec\": \"Cosine Similarity w/o Precursor\",\n", + " }\n", ")\n", - "plot_long[\"metric\"] = plot_long[\"metric\"].map({\n", - " \"spectral_sqrt_cosine\": \"Cosine Similarity\",\n", - " \"spectral_sqrt_cosine_wo_prec\": \"Cosine Similarity w/o Precursor\"\n", - "})\n", "\n", "plt.figure(figsize=(10, 5))\n", "sns.pointplot(\n", - " data=plot_long, x=\"depth\", y=\"score\", hue=\"metric\",\n", - " estimator=np.median, errorbar=\"sd\", n_boot=1000, markers=\"o\", dodge=0.2, capsize=0.2\n", + " data=plot_long,\n", + " x=\"depth\",\n", + " y=\"score\",\n", + " hue=\"metric\",\n", + " estimator=np.median,\n", + " errorbar=\"sd\",\n", + " n_boot=1000,\n", + " markers=\"o\",\n", + " dodge=0.2,\n", + " capsize=0.2,\n", ")\n", "plt.xticks(sorted(plot_long[\"depth\"].unique().tolist()))\n", "plt.xlabel(\"Depth\")\n", @@ -2272,11 +2679,13 @@ "metadata": {}, "outputs": [], "source": [ - "dev=\"cuda:0\"\n", - "mymy = GNNCompiler.load(MODEL_PATH) # f\"{home}/data/metabolites/pretrained_models/v0.0.1_merged_2.pt\"\n", - "#mymy.load_state_dict(torch.load(f\"{home}/data/metabolites/pretrained_models/test.pt\"))\n", + "dev = \"cuda:0\"\n", + "mymy = GNNCompiler.load(\n", + " MODEL_PATH\n", + ") # f\"{home}/data/metabolites/pretrained_models/v0.0.1_merged_2.pt\"\n", + "# 
mymy.load_state_dict(torch.load(f\"{home}/data/metabolites/pretrained_models/test.pt\"))\n",
     "mymy.eval()\n",
-    "mymy = mymy.to(dev)\n"
+    "mymy = mymy.to(dev)"
    ]
   },
   {
@@ -2326,7 +2735,8 @@
    "outputs": [],
    "source": [
     "import json\n",
+    "\n",
-    "with open(MODEL_PATH.replace(\".pt\", \"_params.json\"), 'r') as fp:\n",
+    "with open(MODEL_PATH.replace(\".pt\", \"_params.json\"), \"r\") as fp:\n",
     "    p = json.load(fp)\n",
     "hh = GNNCompiler(p)\n",
     "hh.load_state_dict(torch.load(MODEL_PATH.replace(\".pt\", \"_state.pt\")))\n",
@@ -2403,19 +2813,24 @@
    "source": [
     "## prepare output for for CFM-ID\n",
     "import os\n",
+    "\n",
     "save_df = False\n",
     "cfm_directory = f\"{home}/data/metabolites/cfm-id/\"\n",
     "name = \"test_split_negative_solutions_cfm.txt\"\n",
     "df_cfm = df_test[[\"group_id\", \"SMILES\", \"Precursor_type\"]]\n",
-    "df_n = df_cfm[df_cfm[\"Precursor_type\"] == \"[M-H]-\"].drop_duplicates(subset='group_id', keep='first')\n",
-    "df_p = df_cfm[df_cfm[\"Precursor_type\"] == \"[M+H]+\"].drop_duplicates(subset='group_id', keep='first')\n",
+    "df_n = df_cfm[df_cfm[\"Precursor_type\"] == \"[M-H]-\"].drop_duplicates(\n",
+    "    subset=\"group_id\", keep=\"first\"\n",
+    ")\n",
+    "df_p = df_cfm[df_cfm[\"Precursor_type\"] == \"[M+H]+\"].drop_duplicates(\n",
+    "    subset=\"group_id\", keep=\"first\"\n",
+    ")\n",
     "\n",
     "print(df_n.head())\n",
     "\n",
     "if save_df:\n",
     "    file = os.path.join(cfm_directory, name)\n",
     "    df_n[[\"group_id\", \"SMILES\"]].to_csv(file, index=False, header=False, sep=\" \")\n",
-    "    \n",
+    "\n",
     "    name = name.replace(\"negative\", \"positive\")\n",
     "    file = os.path.join(cfm_directory, name)\n",
     "    df_p[[\"group_id\", \"SMILES\"]].to_csv(file, index=False, header=False, sep=\" \")"
diff --git a/pyproject.toml b/pyproject.toml
index 3198267..8f30322 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,9 +17,11 @@ classifiers = [
 ]
 dependencies = [
     "numpy",
+    "pandas",
     "seaborn",
     "torch",
     "torch_geometric>=2.6,<2.7",
+    "torchmetrics",
     "dill",
     "rdkit",
     "treelib",
@@ -27,9 +29,13 @@ dependencies = [
     "ipython>=8",
 ]
 
+[project.scripts]
+fiora-predict = "fiora.cli.predict:main"
+fiora-train = "fiora.cli.train:main"
+
 [tool.setuptools]
 include-package-data = true
-script-files = ["scripts/fiora-predict"] # like setup.py `scripts=[...]` (not console entry points)
+script-files = ["scripts/fiora-predict", "scripts/fiora-train"] # like setup.py `scripts=[...]` (not console entry points)
 
 [tool.setuptools.packages.find]
 include = ["fiora", "fiora.*", "models"]
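The `[project.scripts]` table added above turns `fiora-predict` and `fiora-train` into proper console entry points: the installer generates a small launcher that imports the declared `module:function` target and calls it. Roughly, and only as an illustrative sketch of what such a generated shim does (the real wrapper is produced by pip/setuptools, not shipped in this patch):

```python
# Illustrative approximation of the console-script shim generated for
# fiora-predict = "fiora.cli.predict:main"; not a file in this repository.
import sys

from fiora.cli.predict import main

if __name__ == "__main__":
    sys.exit(main())
```

The `scripts/fiora-predict` file in the next diff is accordingly reduced to the same thin wrapper, kept only for the legacy `script-files` mechanism.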
diff --git a/scripts/fiora-predict b/scripts/fiora-predict
index b1fbe8c..d418e19 100644
--- a/scripts/fiora-predict
+++ b/scripts/fiora-predict
@@ -1,201 +1,6 @@
 #! /usr/bin/env python
-import pandas as pd
-import os
-import warnings
-from rdkit import RDLogger
-RDLogger.DisableLog("rdApp.*")
-warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
-
-import argparse
-import importlib.resources
-import fiora.IO.mgfWriter as mgfWriter
-import fiora.IO.mspWriter as mspWriter
-
-from fiora.GNN.FioraModel import FioraModel
-from fiora.MS.SimulationFramework import SimulationFramework
-from fiora.MOL.Metabolite import Metabolite
-from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder
-from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder
-from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder
-
-def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(prog='fiora-predict',
-                                     description='Fiora is an in silico fragmentation framework, which predicts peaks and simulates tandem mass spectra including features such as retention time and collision cross sections. Use this script for spectrum predictions with a (pre-)trained model.',
-                                     epilog='Disclaimer:\nNo prediction software is perfect. Use with caution.')
-    parser.add_argument("-i", "--input", help="Input file containing molecular structures (SMILES/InChi) and metadata (.csv file)", type=str, required=True)
-    parser.add_argument("-o", "--output", help="Output file path (.mgf/.msp file)", type=str, required=True)
-    parser.add_argument("--model", help="Path to prediction model (.pt file)", type=str, default="default")
-    parser.add_argument("--dev", help="Device to the model. For example cuda:0 for GPU number 0.", type=str, default="cpu")
-    parser.add_argument("--min_prob", help="Minimum peak probability to be recorded in the spectrum", type=float, default=0.001)
-
-    parser.add_argument('--rt', action=argparse.BooleanOptionalAction, help="Predict retention time", default=False)
-    parser.add_argument('--ccs', action=argparse.BooleanOptionalAction, help="Predict collison cross section", default=False)
-    parser.add_argument("--annotation", action=argparse.BooleanOptionalAction, help="Annotate predicted peaks with SMILES strings", default=False)
-    parser.add_argument("--debug", action=argparse.BooleanOptionalAction, help="Receive debug information", default=False)
-    args = parser.parse_args()
-
-    return args
-
-def update_args_with_model_params(args: argparse.Namespace, model_params: dict) -> argparse.Namespace:
-
-    if "rt_supported" in model_params.keys():
-        if not model_params["rt_supported"] and args.rt:
-            print("Warning: RT prediction is not support by the model. Overwriting user argument to --no-rt.\n")
-            args.rt = False
-    if "ccs_supported" in model_params.keys():
-        if not model_params["ccs_supported"] and args.ccs:
-            print("Warning: CCS prediction is not support by the model. 
Overwriting user argument to --no-ccs.\n") - args.ccs = False - return args - -def print_model_messages(model_params: dict) -> None: - if "version" in model_params.keys(): - print(f"\n-----Model-----") - print(model_params["version"]) - print(f"---------------") - if "disclaimer" in model_params.keys(): - dis_msg = model_params["disclaimer"] - print(f"\nDisclaimer: {dis_msg}") - -metadata_key_map = { - "name": "Name", - "collision_energy": "CE", - "instrument": "Instrument_type", - "precursor_mode": "Precursor_type", - } - -def safe_metabolite_creation(smiles): - try: - return Metabolite(smiles) - except (AssertionError, ValueError): - return None - -def build_metabolites(df: pd.DataFrame, model_params: dict): - - # Set feature encoder up - CE_upper_limit = 100.0 - weight_upper_limit = 1000.0 - - model_setup_feature_sets = None - if "setup_features_categorical_set" in model_params.keys(): - model_setup_feature_sets = model_params["setup_features_categorical_set"] - - node_encoder = AtomFeatureEncoder(feature_list=["symbol", "num_hydrogen", "ring_type"]) - bond_encoder = BondFeatureEncoder(feature_list=["bond_type", "ring_type"]) - if model_params["version_number"] == "0.1.0": - covariate_features = ["collision_energy", "molecular_weight", "precursor_mode", "instrument"] - else: - covariate_features =["collision_energy", "molecular_weight", "precursor_mode", "instrument", "element_composition"] - setup_encoder = CovariateFeatureEncoder(feature_list=covariate_features, sets_overwrite=model_setup_feature_sets) - rt_encoder = CovariateFeatureEncoder(feature_list=["molecular_weight", "precursor_mode", "instrument"], sets_overwrite=model_setup_feature_sets) - - setup_encoder.normalize_features["collision_energy"]["max"] = CE_upper_limit - setup_encoder.normalize_features["molecular_weight"]["max"] = weight_upper_limit - rt_encoder.normalize_features["molecular_weight"]["max"] = weight_upper_limit - - # Convert SMILES to Metabolites and create structure graphs and fragmentation trees - df["Metabolite"] = df["SMILES"].apply(safe_metabolite_creation) - invalid_df = df[df["Metabolite"].isna()][["Name", "SMILES"]] - df.dropna(subset=["Metabolite"], inplace=True) - - df["Metabolite"].apply(lambda x: x.create_molecular_structure_graph()) - df["Metabolite"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)) - - # Map covariate features to dedicated format and encode - df["summary"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1) - df.apply(lambda x: x["Metabolite"].add_metadata(x["summary"], setup_encoder, rt_encoder), axis=1) - - # Fragment compounds - df["Metabolite"].apply(lambda x: x.fragment_MOL(depth=1)) - #df.apply(lambda x: x["Metabolite"].match_fragments_to_peaks(x["peaks"]["mz"], x["peaks"]["intensity"], tolerance=x["ppm_peak_tolerance"]), axis=1) - - return df, invalid_df - -def prepare_output(args, df, model): - df["peaks"] = df["sim_peaks"] - df["Formula"] = df["Metabolite"].apply(lambda x: x.Formula) - df["Precursor_MZ"] = df["Metabolite"].apply(lambda x: x.get_theoretical_precursor_mz(ion_type=x.metadata["precursor_mode"])) - - # Rename certain columns - if "RT_pred" in df.columns: - df["RETENTIONTIME"] = df["RT_pred"] - df["PRECURSOR_MZ"] = df["Precursor_MZ"] - df["FORMULA"] = df["Formula"] - if "CCS_pred" in df.columns: - df["CCS"] = df["CCS_pred"] - version = model.model_params["version"] if "version" in model.model_params else "(pre-release version v0.0.0)" - df["Comment"] = f"\"In silico generated spectrum by {version}\"" - 
df["COMMENT"] = df["Comment"] - - # Write output file - if args.output.endswith(".msp"): - df["Collision_energy"] = df["CE"] - headers=["Name", "SMILES", "Formula", "Precursor_MZ", "Precursor_type", "Instrument_type", "Collision_energy"] - if args.rt: headers.append("RETENTIONTIME") - if args.ccs: headers.append("CCS") - headers.append("Comment") - mspWriter.write_msp(df, path=args.output, write_header=True, headers=headers, annotation=args.annotation) - elif args.output.endswith(".mgf"): - headers = ["TITLE", "SMILES", "FORMULA", "PRECURSOR_MZ", "PRECURSORTYPE", "COLLISIONENERGY", "INSTRUMENTTYPE"] - if args.rt: headers.append("RETENTIONTIME") - if args.ccs: headers.append("CCS") - headers.append("COMMENT") - mgfWriter.write_mgf(df, path=args.output, write_header=True, headers=headers, header_map={"TITLE": "Name", "PRECURSORTYPE": "Precursor_type", "INSTRUMENTTYPE": "Instrument_type", "COLLISIONENERGY": "CE"}, annotation=args.annotation) - else: - print(f"Warning: Unknown output format {args.output}. Writing results to {args.output}.mgf instead.") - args.output = args.output + ".mgf" - headers = ["TITLE", "SMILES", "FORMULA", "PRECURSORTYPE", "COLLISIONENERGY", "INSTRUMENTTYPE"] - if args.rt: headers.append("RETENTIONTIME") - if args.ccs: headers.append("CCS") - headers.append("COMMENT") - mgfWriter.write_mgf(df, path=args.output, write_header=True, headers=headers, header_map={"TITLE": "Name", "PRECURSORTYPE": "Precursor_type", "INSTRUMENTTYPE": "Instrument_type", "COLLISIONENERGY": "CE"}, annotation=args.annotation) - - -def main(): - args = parse_args() - if args.debug: - print(f'Running fiora prediction with the following parameters: {args}\n') - - # Load model - if args.model == "default": - with importlib.resources.path('models', 'fiora_OS_v1.0.0.pt') as model_path: - args.model = str(model_path) - - try: - model = FioraModel.load_from_state_dict(args.model) - except Exception as e: - print(f"Error: Failed loading from model from state dict. Caused by: {e}.") - exit(1) - - print_model_messages(model.model_params) - args = update_args_with_model_params(args, model.model_params) - - - model.eval() - model = model.to(args.dev) - - # Set up Fiora - fiora = SimulationFramework(None, dev=args.dev) - - # Load the data - df = pd.read_csv(args.input) - - # Construct molecular structure graphs and fragmentation trees - df, invalid_df = build_metabolites(df, model.model_params) - if invalid_df.shape[0] > 0: - if args.debug: - print("Warning: The following input SMILES could not be read or formatted:") - print(invalid_df) - else: - print("Warning: Some SMILES could not be read or formatted. Run with --debug flag for more information.") - - # Simulate compound fragmentation - df = fiora.simulate_all(df, model, groundtruth=False, min_intensity=args.min_prob) - - # Prepare Output - prepare_output(args, df, model) - print(f"Finished prediction. Exported MS/MS spectra to {args.output}.") +from fiora.cli.predict import main if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/fiora-train b/scripts/fiora-train new file mode 100644 index 0000000..9c2b904 --- /dev/null +++ b/scripts/fiora-train @@ -0,0 +1,6 @@ +#! 
/usr/bin/env python +from fiora.cli.train import main + + +if __name__ == "__main__": + main() diff --git a/tests/test_fiora_predict.py b/tests/test_fiora_predict.py index fcc341d..74c4291 100644 --- a/tests/test_fiora_predict.py +++ b/tests/test_fiora_predict.py @@ -13,19 +13,22 @@ ## Importing fiora predict (from executable) from importlib.util import spec_from_loader, module_from_spec -from importlib.machinery import SourceFileLoader -spec = spec_from_loader("fiora-predict", SourceFileLoader("fiora-predict", os.getcwd() + "/scripts/fiora-predict")) +from importlib.machinery import SourceFileLoader + +spec = spec_from_loader( + "fiora-predict", + SourceFileLoader("fiora-predict", os.getcwd() + "/scripts/fiora-predict"), +) fiora_predict = module_from_spec(spec) spec.loader.exec_module(fiora_predict) -sys.modules['fiora_predict'] = fiora_predict +sys.modules["fiora_predict"] = fiora_predict class TestFioraPredict(unittest.TestCase): - @classmethod def setUpClass(cls): cls.temp_path = "temp_spec.mgf" - + @classmethod def tearDownClass(cls): if os.path.exists(cls.temp_path): @@ -37,9 +40,8 @@ def test_missing_args(self): with self.assertRaises(SystemExit) as cm, contextlib.redirect_stderr(f): fiora_predict.main() self.assertEqual(cm.exception.code, 2) - self.assertTrue(f.getvalue().startswith("usage:")) - - + self.assertTrue(f.getvalue().startswith("usage:")) + def test_help(self): f = io.StringIO() with patch("sys.argv", ["main", "-h"]): @@ -47,19 +49,21 @@ def test_help(self): fiora_predict.main() self.assertEqual(cm.exception.code, 0) self.assertTrue(f.getvalue().startswith("usage:")) - self.assertTrue("-h, --help" in f.getvalue()) - self.assertTrue("show this help message and exit" in f.getvalue()) - + self.assertTrue("-h, --help" in f.getvalue()) + self.assertTrue("show this help message and exit" in f.getvalue()) def test_dummy(self): - self.assertEqual('fiora'.upper(), 'FIORA') + self.assertEqual("fiora".upper(), "FIORA") def test_model_cpu(self): f = io.StringIO() - with patch("sys.argv", ["main", "-i", "examples/example_input.csv", "-o", self.temp_path]): + with patch( + "sys.argv", + ["main", "-i", "examples/example_input.csv", "-o", self.temp_path], + ): with contextlib.redirect_stdout(f): fiora_predict.main() - self.assertIn("Finished prediction.", f.getvalue()) + self.assertIn("Finished prediction.", f.getvalue()) self.assertTrue(os.path.exists(self.temp_path)) def test_model_output_integrity(self): @@ -67,8 +71,14 @@ def test_model_output_integrity(self): df_expected = mgfReader.read(expected_output, as_df=True) df_new = mgfReader.read(self.temp_path, as_df=True) - - columns = ["TITLE", "SMILES", "PRECURSORTYPE", "COLLISIONENERGY", "INSTRUMENTTYPE"] + + columns = [ + "TITLE", + "SMILES", + "PRECURSORTYPE", + "COLLISIONENERGY", + "INSTRUMENTTYPE", + ] self.assertDictEqual(df_expected[columns].to_dict(), df_new[columns].to_dict()) for i, data in df_expected.iterrows(): peaks_expected = data["peaks"] @@ -76,17 +86,20 @@ def test_model_output_integrity(self): cosine = spectral_cosine(peaks_expected, peaks_new, transform=np.sqrt) self.assertGreater(cosine, 0.99) -if __name__ == '__main__': + +if __name__ == "__main__": # unittest.main() suite = unittest.TestSuite() - - suite.addTests([ - TestFioraPredict('test_dummy'), - TestFioraPredict('test_help'), - TestFioraPredict('test_missing_args'), - TestFioraPredict('test_model_cpu'), - TestFioraPredict('test_model_output_integrity'), - ]) - + + suite.addTests( + [ + TestFioraPredict("test_dummy"), + TestFioraPredict("test_help"), + 
TestFioraPredict("test_missing_args"), + TestFioraPredict("test_model_cpu"), + TestFioraPredict("test_model_output_integrity"), + ] + ) + runner = unittest.TextTestRunner() - runner.run(suite) \ No newline at end of file + runner.run(suite) From 4b19915abd896e262ef07681e522f5da656b7f33 Mon Sep 17 00:00:00 2001 From: Philipp Benner Date: Mon, 23 Feb 2026 15:27:56 +0100 Subject: [PATCH 02/15] added missing dependency on scikit-learn --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 8f30322..a473617 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "treelib", "spectrum_utils", "ipython>=8", + "scikit-learn", ] [project.scripts] From 1a741509f42f98b847af5aa3516977095ad3cb11 Mon Sep 17 00:00:00 2001 From: Philipp Benner Date: Thu, 5 Mar 2026 17:29:04 +0100 Subject: [PATCH 03/15] finalized train and preprocessing scripts --- README.md | 33 ++ constraints.txt | 1 + fiora/GNN/SpectralTrainer.py | 111 +++- fiora/cli/predict.py | 2 +- fiora/cli/train.py | 72 ++- .../resources/models}/__init__.py | 0 .../resources/models}/fiora_OS_v0.1.0.pt | Bin .../models}/fiora_OS_v0.1.0_params.json | 0 .../models}/fiora_OS_v0.1.0_state.pt | Bin .../resources/models}/fiora_OS_v1.0.0.pt | Bin .../models}/fiora_OS_v1.0.0_params.json | 0 .../models}/fiora_OS_v1.0.0_state.pt | Bin notebooks/live_predict.ipynb | 2 +- notebooks/test_model.ipynb | 8 +- pyproject.toml | 5 +- requirements.txt | 1 + resources/data/msnlib/Makefile | 17 + resources/data/msnlib/download_msnlib.py | 202 +++++++ resources/data/msnlib/preprocess_msnlib.py | 497 ++++++++++++++++++ 19 files changed, 905 insertions(+), 46 deletions(-) rename {models => fiora/resources/models}/__init__.py (100%) rename {models => fiora/resources/models}/fiora_OS_v0.1.0.pt (100%) rename {models => fiora/resources/models}/fiora_OS_v0.1.0_params.json (100%) rename {models => fiora/resources/models}/fiora_OS_v0.1.0_state.pt (100%) rename {models => fiora/resources/models}/fiora_OS_v1.0.0.pt (100%) rename {models => fiora/resources/models}/fiora_OS_v1.0.0_params.json (100%) rename {models => fiora/resources/models}/fiora_OS_v1.0.0_state.pt (100%) create mode 100644 resources/data/msnlib/Makefile create mode 100644 resources/data/msnlib/download_msnlib.py create mode 100644 resources/data/msnlib/preprocess_msnlib.py diff --git a/README.md b/README.md index 7d5b018..1647164 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,39 @@ Run the fiora-predict from within this directory By default, an open-source model is selected automatically, and predictions typically complete within a few seconds. For faster performance, specify a GPU device using the `--dev` option (e.g., `--dev cuda:0`). The output file (e.g., examples/example_spec.mgf) can be compared with the [expected results](examples/expected_output.mgf) to verify model accuracy. This verification is automatically performed by running pytest (as described above). +### Models and Resources + +Default model checkpoints are packaged under `fiora/resources/models` (Python package: `fiora.resources.models`). The CLI uses these automatically when `--model default` is selected. + +Scripts for downloading and preprocessing MSnLib are provided in `resources/data/msnlib` (`download_msnlib.py` and `preprocess_msnlib.py`). + +The downloader defaults to MSnLib v7 on Zenodo (`https://zenodo.org/records/16984129`) and accepts both direct file URLs and Zenodo record URLs. 
For Zenodo record URLs it downloads all files matching `*_ms2.mgf` by default. + +```bash +python resources/data/msnlib/download_msnlib.py +python resources/data/msnlib/preprocess_msnlib.py +``` + +Use `--record-pattern` to select a different subset, e.g. `--record-pattern "*_pos_*.mgf"`. + +### MSnLib Training Parity (Notebook vs CLI) + +The training notebooks override categorical feature sets for MSnLib: + +- `instrument`: `["HCD"]` +- `precursor_mode`: `["[M+H]+", "[M-H]-", "[M]+", "[M]-"]` + +To match notebook training results when using `fiora-train`, pass the same overrides: + +```bash +fiora-train \ + -i resources/data/msnlib/library.csv \ + -o checkpoints/fiora.pt \ + --device cuda:0 \ + --instruments HCD \ + --precursor-modes "[M+H]+,[M-H]-,[M]+,[M]-" +``` + ## The Algorithm FIORA has been developed as a computational tool to predict bond cleavages that occur in the MS/MS fragmentation process and estimate the probabilities of resulting fragment ions. To that end, FIORA utilizes graph neural networks to learn local molecular neighborhoods around bonds, combined with edge prediction to simulate bond dissociation. The prediction determines which fragment (left or right of the bond cleavage, with up to four possible hydrogen losses) retains the charge and which becomes the neutral loss. The figure below illustrates an example fragmentation prediction for a single bond. diff --git a/constraints.txt b/constraints.txt index e8aa544..7332e87 100644 --- a/constraints.txt +++ b/constraints.txt @@ -121,6 +121,7 @@ pytz==2025.2 PyYAML==6.0.2 pyzmq==26.4.0 rdkit==2024.9.6 +regex==2024.11.6 referencing==0.36.2 requests==2.32.3 rfc3339-validator==0.1.4 diff --git a/fiora/GNN/SpectralTrainer.py b/fiora/GNN/SpectralTrainer.py index 4a63d76..4250d61 100644 --- a/fiora/GNN/SpectralTrainer.py +++ b/fiora/GNN/SpectralTrainer.py @@ -15,6 +15,8 @@ GNN Trainer """ +TQDM_DATA_THRESHOLD = 10000 + class SpectralTrainer(Trainer): def __init__( @@ -65,6 +67,38 @@ def __init__( geom_loader.DataLoader if library == "geometric" else DataLoader ) + @staticmethod + def _to_float(value): + if isinstance(value, torch.Tensor): + return float(value.detach().cpu().item()) + return float(value) + + @staticmethod + def _build_progress_iterator(dataloader, enabled=False, desc=""): + if not enabled: + return dataloader + try: + from tqdm.auto import tqdm + + return tqdm(dataloader, total=len(dataloader), desc=desc, leave=False) + except Exception: + return dataloader + + @staticmethod + def _format_metric(stats): + if "mse" in stats: + rmse = torch.sqrt(stats["mse"]) + return "rmse", float(rmse.detach().cpu().item()) + if "mae" in stats: + return "mae", float(stats["mae"].detach().cpu().item()) + if "acc" in stats: + return "acc", float(stats["acc"].detach().cpu().item()) + key = next(iter(stats.keys())) + val = stats[key] + if isinstance(val, torch.Tensor): + val = float(val.detach().cpu().item()) + return key, float(val) + def _training_loop( self, model, @@ -76,12 +110,17 @@ def _training_loop( with_RT=False, with_CCS=False, rt_metric=False, - title="", + show_progress=False, + progress_desc="Train", ): training_loss = 0 metrics.increment() + num_batches = 0 - for id, batch in enumerate(dataloader): + iterator = self._build_progress_iterator( + dataloader, enabled=show_progress, desc=progress_desc + ) + for _, batch in enumerate(iterator): # Feed forward model.train() @@ -138,21 +177,13 @@ def _training_loop( optimizer.zero_grad() loss.backward() optimizer.step() + training_loss += self._to_float(loss) + num_batches += 
1 # End of training cycle: Evaluation stats = metrics.compute() - training_loss /= len(dataloader) - - if self.problem_type == "classification": - print( - f"{title} Training Accuracy: {stats['acc']:>.3f} (Loss per batch: {'NOT TRACKED'})", - end="\r", - ) - else: - print( - f"{title} RMSE: {torch.sqrt(stats['mse']):>.4f}", end="\r" - ) # MSE: {stats["mse"]:>.3f}; MAE: {stats["mae"]:>.3f} - return stats + training_loss /= max(num_batches, 1) + return stats, training_loss def _validation_loop( self, @@ -165,11 +196,17 @@ def _validation_loop( with_CCS=False, rt_metric=False, mask_name=None, - title="Validation", + show_progress=False, + progress_desc="Validation", ): metrics.increment() + validation_loss = 0 + num_batches = 0 with torch.no_grad(): - for id, batch in enumerate(dataloader): + iterator = self._build_progress_iterator( + dataloader, enabled=show_progress, desc=progress_desc + ) + for _, batch in enumerate(iterator): model.eval() y_pred = model(batch, with_RT=with_RT, with_CCS=with_CCS) if mask_name: @@ -187,6 +224,11 @@ def _validation_loop( metrics.update( y_pred["fragment_probs"], batch[self.y_tag], **kwargs ) + batch_loss = loss_fn( + y_pred["fragment_probs"], batch[self.y_tag], **kwargs + ) + validation_loss += self._to_float(batch_loss) + num_batches += 1 if rt_metric: metrics( y_pred["rt"][batch["retention_mask"]], @@ -201,8 +243,11 @@ def _validation_loop( # End of Validation cycle stats = metrics.compute() - print(f"\t{title} RMSE: {torch.sqrt(stats['mse']):>.4f}") - return stats + if num_batches > 0: + validation_loss /= num_batches + else: + validation_loss = float("nan") + return stats, validation_loss # Training function def train( @@ -247,21 +292,26 @@ def train( using_weighted_loss_func = isinstance(loss_fn, WeightedMSELoss) | isinstance( loss_fn, WeightedMAELoss ) + show_train_progress = len(self.training_data) > TQDM_DATA_THRESHOLD + show_val_progress = ( + (not self.only_training) and (len(self.validation_data) > TQDM_DATA_THRESHOLD) + ) # Main loop for e in range(epochs): # Training - train_stats = self._training_loop( + train_stats, train_loss = self._training_loop( model, training_loader, optimizer, loss_fn, self.metrics["train"], - title=f"Epoch {e + 1}/{epochs}: ", with_weights=using_weighted_loss_func, with_RT=with_RT, with_CCS=with_CCS, rt_metric=rt_metric, + show_progress=show_train_progress, + progress_desc=f"Train {e + 1}/{epochs}", ) # Validation @@ -269,7 +319,7 @@ def train( (e + 1) % val_every_n_epochs == 0 ) if is_val_cycle: - val_stats = self._validation_loop( + val_stats, val_loss = self._validation_loop( model, validation_loader, loss_fn, @@ -281,8 +331,25 @@ def train( with_CCS=with_CCS, rt_metric=rt_metric, mask_name=mask_name if use_validation_mask else None, - title="Masked Validation" if use_validation_mask else "Validation", + show_progress=show_val_progress, + progress_desc=f"Val {e + 1}/{epochs}", ) + else: + val_stats, val_loss = None, float("nan") + + train_metric_name, train_metric_value = self._format_metric(train_stats) + if val_stats is not None: + val_metric_name, val_metric_value = self._format_metric(val_stats) + val_metric_str = f"val_{val_metric_name}: {val_metric_value:.4f}" + else: + val_metric_str = "val_metric: n/a" + val_loss_str = f"{val_loss:.4f}" if not np.isnan(val_loss) else "n/a" + print( + f"Epoch {e + 1}/{epochs} - loss: {train_loss:.4f} - " + f"val_loss: {val_loss_str} - " + f"train_{train_metric_name}: {train_metric_value:.4f} - " + f"{val_metric_str}" + ) # End of epoch: Advance scheduler if scheduler: diff 
--git a/fiora/cli/predict.py b/fiora/cli/predict.py index c230f80..71cab0b 100644 --- a/fiora/cli/predict.py +++ b/fiora/cli/predict.py @@ -316,7 +316,7 @@ def main() -> None: # Load model if args.model == "default": with resources.as_file( - resources.files("models").joinpath("fiora_OS_v1.0.0.pt") + resources.files("fiora.resources.models").joinpath("fiora_OS_v1.0.0.pt") ) as model_path: args.model = str(model_path) diff --git a/fiora/cli/train.py b/fiora/cli/train.py index 7950e99..44ee96e 100644 --- a/fiora/cli/train.py +++ b/fiora/cli/train.py @@ -3,6 +3,7 @@ import ast import json import os +import re import warnings import numpy as np @@ -180,9 +181,21 @@ def _parse_dict(val): return val if val is None or (isinstance(val, float) and np.isnan(val)): return None - text = str(val) + text = str(val).strip() + if not text: + return None + try: + # Handles canonical JSON and JSON with NaN/Infinity tokens. + return json.loads(text) + except Exception: + pass + # Fallback for python-literal style dict strings. + norm = re.sub(r"\b(?:NaN|nan)\b", "None", text) + norm = re.sub(r"\b(?:Infinity|inf)\b", "1e309", norm) + norm = re.sub(r"\b(?:-Infinity|-inf)\b", "-1e309", norm) try: - return ast.literal_eval(text.replace("nan", "None")) + parsed = ast.literal_eval(norm) + return parsed if isinstance(parsed, dict) else None except Exception: return None @@ -203,11 +216,17 @@ def _safe_metabolite(smiles: str): def _build_summary_from_columns(row, metadata_key_map): summary = {} - for key, col in metadata_key_map.items(): - if col in row.index: - value = row[col] - if value is not None and not (isinstance(value, float) and np.isnan(value)): - summary[key] = value + for key, cols in metadata_key_map.items(): + if not isinstance(cols, (list, tuple)): + cols = [cols] + for col in cols: + if col in row.index: + value = row[col] + if value is not None and not ( + isinstance(value, float) and np.isnan(value) + ): + summary[key] = value + break return summary @@ -299,13 +318,13 @@ def main() -> None: rt_encoder.normalize_features["molecular_weight"]["max"] = args.weight_upper_limit metadata_key_map = { - "name": "Name", - "collision_energy": "CE", - "instrument": "Instrument_type", - "precursor_mode": "Precursor_type", - "precursor_mz": "PrecursorMZ", - "retention_time": "RETENTIONTIME", - "ccs": "CCS", + "name": ["Name", "NAME", "Title", "TITLE"], + "collision_energy": ["CE", "COLLISION_ENERGY", "CollisionEnergy"], + "instrument": ["Instrument_type", "instrument", "INSTRUMENT_TYPE"], + "precursor_mode": ["Precursor_type", "ADDUCT", "PRECURSORTYPE"], + "precursor_mz": ["PrecursorMZ", "PEPMASS", "PRECURSORMZ"], + "retention_time": ["RETENTIONTIME", "RTINSECONDS", "retention_time"], + "ccs": ["CCS", "ccs"], } # Build metabolites @@ -441,8 +460,30 @@ def main() -> None: print(f"Prepared training/validation with {len(geo_data)} data points") # Model params + default_params = { + "param_tag": "default", + "gnn_type": "RGCNConv", + "depth": 10, + "hidden_dimension": 300, + "residual_connections": False, + "layer_stacking": True, + "embedding_aggregation": "concat", + "embedding_dimension": 300, + "subgraph_features": True, + "pooling_func": "max", + "layer_norm": True, + "dense_layers": 2, + "dense_dim": 500, + "input_dropout": 0.25, + "latent_dropout": 0.25, + "prepare_additional_layers": False, + "rt_supported": False, + "ccs_supported": False, + "version": "x.x.x", + } base_params = _load_model_params(args.model_params) - model_params = dict(base_params) + model_params = dict(default_params) + 
model_params.update(base_params) model_params.update( { "node_feature_layout": node_encoder.feature_numbers, @@ -509,7 +550,6 @@ def main() -> None: patience=args.scheduler_patience, factor=args.scheduler_factor, mode="min", - verbose=True, ) output_path = args.output diff --git a/models/__init__.py b/fiora/resources/models/__init__.py similarity index 100% rename from models/__init__.py rename to fiora/resources/models/__init__.py diff --git a/models/fiora_OS_v0.1.0.pt b/fiora/resources/models/fiora_OS_v0.1.0.pt similarity index 100% rename from models/fiora_OS_v0.1.0.pt rename to fiora/resources/models/fiora_OS_v0.1.0.pt diff --git a/models/fiora_OS_v0.1.0_params.json b/fiora/resources/models/fiora_OS_v0.1.0_params.json similarity index 100% rename from models/fiora_OS_v0.1.0_params.json rename to fiora/resources/models/fiora_OS_v0.1.0_params.json diff --git a/models/fiora_OS_v0.1.0_state.pt b/fiora/resources/models/fiora_OS_v0.1.0_state.pt similarity index 100% rename from models/fiora_OS_v0.1.0_state.pt rename to fiora/resources/models/fiora_OS_v0.1.0_state.pt diff --git a/models/fiora_OS_v1.0.0.pt b/fiora/resources/models/fiora_OS_v1.0.0.pt similarity index 100% rename from models/fiora_OS_v1.0.0.pt rename to fiora/resources/models/fiora_OS_v1.0.0.pt diff --git a/models/fiora_OS_v1.0.0_params.json b/fiora/resources/models/fiora_OS_v1.0.0_params.json similarity index 100% rename from models/fiora_OS_v1.0.0_params.json rename to fiora/resources/models/fiora_OS_v1.0.0_params.json diff --git a/models/fiora_OS_v1.0.0_state.pt b/fiora/resources/models/fiora_OS_v1.0.0_state.pt similarity index 100% rename from models/fiora_OS_v1.0.0_state.pt rename to fiora/resources/models/fiora_OS_v1.0.0_state.pt diff --git a/notebooks/live_predict.ipynb b/notebooks/live_predict.ipynb index f3a887e..705eb81 100644 --- a/notebooks/live_predict.ipynb +++ b/notebooks/live_predict.ipynb @@ -73,7 +73,7 @@ "outputs": [], "source": [ "depth = 6\n", - "MODEL_PATH = f\"../models/fiora_OS_v0.1.0.pt\"\n", + "MODEL_PATH = f\"../resources/models/fiora_OS_v0.1.0.pt\"\n", "\n", "try:\n", " model = GNNCompiler.load(MODEL_PATH)\n", diff --git a/notebooks/test_model.ipynb b/notebooks/test_model.ipynb index d9c570a..cc5d3b4 100644 --- a/notebooks/test_model.ipynb +++ b/notebooks/test_model.ipynb @@ -82,13 +82,13 @@ "# MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/pre_package/v0.0.1_2_OS_depth{depth}_June24+CCS+RT.pt\" # OS model (first try)\n", "\n", "# NEW AND SHINY\n", - "MODEL_PATH = f\"../models/fiora_OS_v1.0.0.pt\" # Release version\n", + "MODEL_PATH = f\"../resources/models/fiora_OS_v1.0.0.pt\" # Release version\n", "# MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v1.0.0_OS_depth10_Sep25_x4.pt\"\n", "# MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v0.0.1_merged_depth{depth}_Aug24_sqrt+CCS+RT_drop3.pt\" # New sqrt model (improved) | Note: drop3 uses dropout reduction while training RT, CCS\n", "# MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v0.0.1_OS_depth{depth}_Aug24_sqrt_4.pt\" # or Aug24_sqrt_x are new OS models\n", "\n", "# v: str = \"0.1.0\"\n", - "# MODEL_PATH = f\"../models/fiora_OS_v{v}.pt\" # Release version\n", + "# MODEL_PATH = f\"../resources/models/fiora_OS_v{v}.pt\" # Release version\n", "\n", "from fiora.GNN.FioraModel import FioraModel\n", "from fiora.MS.SimulationFramework import SimulationFramework\n", @@ -17384,7 +17384,7 @@ "metadata": {}, "outputs": [], "source": [ - "model.save(\"../models/fiora_OS_v1.0.0.pt\")" + 
"model.save(\"../resources/models/fiora_OS_v1.0.0.pt\")" ] }, { @@ -18054,7 +18054,7 @@ "metadata": {}, "outputs": [], "source": [ - "# NEW_MODEL_PATH = f\"../models/fiora_OS_v{v}.pt\"\n", + "# NEW_MODEL_PATH = f\"../resources/models/fiora_OS_v{v}.pt\"\n", "# model.save(NEW_MODEL_PATH)" ] }, diff --git a/pyproject.toml b/pyproject.toml index a473617..06b3a42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "torchmetrics", "dill", "rdkit", + "regex", "treelib", "spectrum_utils", "ipython>=8", @@ -39,10 +40,10 @@ include-package-data = true script-files = ["scripts/fiora-predict", "scripts/fiora-train"] # like setup.py `scripts=[...]` (not console entry points) [tool.setuptools.packages.find] -include = ["fiora", "fiora.*", "models"] +include = ["fiora", "fiora.*"] [tool.setuptools.package-data] -models = [ +"fiora.resources.models" = [ "fiora_OS_v0.1.0.pt", "fiora_OS_v0.1.0_state.pt", "fiora_OS_v0.1.0_params.json", diff --git a/requirements.txt b/requirements.txt index e8aa544..7332e87 100644 --- a/requirements.txt +++ b/requirements.txt @@ -121,6 +121,7 @@ pytz==2025.2 PyYAML==6.0.2 pyzmq==26.4.0 rdkit==2024.9.6 +regex==2024.11.6 referencing==0.36.2 requests==2.32.3 rfc3339-validator==0.1.4 diff --git a/resources/data/msnlib/Makefile b/resources/data/msnlib/Makefile new file mode 100644 index 0000000..b773375 --- /dev/null +++ b/resources/data/msnlib/Makefile @@ -0,0 +1,17 @@ +.PHONY: all download preprocess + +PYTHON ?= python3 +DOWNLOAD_ARGS ?= +PREPROCESS_ARGS ?= + +SCRIPT_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST)))) +DOWNLOAD_SCRIPT := $(SCRIPT_DIR)download_msnlib.py +PREPROCESS_SCRIPT := $(SCRIPT_DIR)preprocess_msnlib.py + +all: download preprocess + +download: + $(PYTHON) $(DOWNLOAD_SCRIPT) $(DOWNLOAD_ARGS) + +preprocess: + $(PYTHON) $(PREPROCESS_SCRIPT) $(PREPROCESS_ARGS) diff --git a/resources/data/msnlib/download_msnlib.py b/resources/data/msnlib/download_msnlib.py new file mode 100644 index 0000000..632b51c --- /dev/null +++ b/resources/data/msnlib/download_msnlib.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +"""Download MSnLib archives or files. 
+ +Examples: + python3 resources/data/msnlib/download_msnlib.py \ + --url "https://example.org/msnlib.mgf" \ + --output-dir resources/data/msnlib/raw + + python3 resources/data/msnlib/download_msnlib.py \ + --url "https://example.org/msnlib.zip" \ + --output-dir resources/data/msnlib/raw \ + --extract +""" + +from __future__ import annotations + +import argparse +import fnmatch +import gzip +import json +import re +import shutil +import tarfile +import zipfile +from pathlib import Path +from urllib.parse import urlparse +from urllib.request import urlopen, urlretrieve + + +DEFAULT_URL = "https://zenodo.org/records/16984129" +DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent / "raw" + + +def _default_filename(url: str) -> str: + name = Path(urlparse(url).path).name + return name or "msnlib_download" + + +def _parse_zenodo_record_id(url: str) -> str | None: + m = re.search(r"zenodo\.org/(?:records|record)/(\d+)", url) + if not m: + return None + return m.group(1) + + +def _resolve_zenodo_record( + url: str, + *, + filename_override: str | None, + record_pattern: str | None, + record_max_files: int, +) -> list[tuple[str, str]]: + record_id = _parse_zenodo_record_id(url) + if record_id is None or "/files/" in urlparse(url).path: + resolved_url = url + filename = filename_override or _default_filename(resolved_url) + return [(resolved_url, filename)] + + api_url = f"https://zenodo.org/api/records/{record_id}" + with urlopen(api_url) as resp: + payload = json.loads(resp.read().decode("utf-8")) + + files = payload.get("files", []) + if not files: + raise RuntimeError(f"No downloadable files found in Zenodo record {record_id}") + + selected = [] + for file_item in files: + key = str(file_item.get("key", "")) + if not record_pattern or fnmatch.fnmatch(key, record_pattern): + selected.append(file_item) + + if not selected: + raise RuntimeError( + f"No files in Zenodo record {record_id} match pattern {record_pattern!r}" + ) + + selected = sorted(selected, key=lambda x: str(x.get("key", ""))) + if record_max_files > 0: + selected = selected[:record_max_files] + + if filename_override is not None and len(selected) != 1: + raise RuntimeError( + "--filename can only be used when exactly one file is selected" + ) + + resolved = [] + for file_item in selected: + links = file_item.get("links", {}) + resolved_url = links.get("self") + if not resolved_url: + continue + key = str(file_item.get("key") or _default_filename(resolved_url)) + filename = filename_override or key + resolved.append((resolved_url, filename)) + + if not resolved: + raise RuntimeError(f"Could not resolve any download URLs for record {record_id}") + return resolved + + +def _extract_archive(path: Path, output_dir: Path) -> None: + lower = path.name.lower() + if lower.endswith(".zip"): + with zipfile.ZipFile(path, "r") as zf: + zf.extractall(output_dir) + return + if lower.endswith(".tar.gz") or lower.endswith(".tgz"): + with tarfile.open(path, "r:gz") as tf: + tf.extractall(output_dir) + return + if lower.endswith(".gz") and not lower.endswith(".tar.gz"): + out_path = output_dir / path.with_suffix("").name + with gzip.open(path, "rb") as src, open(out_path, "wb") as dst: + shutil.copyfileobj(src, dst) + return + raise ValueError(f"Unsupported archive format: {path}") + + +def _is_archive(path: Path) -> bool: + lower = path.name.lower() + return ( + lower.endswith(".zip") + or lower.endswith(".tar.gz") + or lower.endswith(".tgz") + or (lower.endswith(".gz") and not lower.endswith(".tar.gz")) + ) + + +def main() -> None: + parser = 
argparse.ArgumentParser(description="Download MSnLib files.") + parser.add_argument( + "--url", + default=DEFAULT_URL, + help=( + "URL to download. Zenodo record URLs are supported and resolved to a file " + "(default: MSnLib v7 Zenodo record)." + ), + ) + parser.add_argument( + "--output-dir", + default=str(DEFAULT_OUTPUT_DIR), + help="Directory to store downloads/extracted files.", + ) + parser.add_argument( + "--filename", + default=None, + help="Optional filename override for the downloaded file.", + ) + parser.add_argument( + "--record-pattern", + default="*_ms2.mgf", + help=( + "Glob pattern for file keys when --url is a Zenodo record " + "(default: *_ms2.mgf)." + ), + ) + parser.add_argument( + "--record-max-files", + type=int, + default=0, + help=( + "Limit number of selected files from a Zenodo record. " + "0 means no limit (default)." + ), + ) + parser.add_argument( + "--extract", + action=argparse.BooleanOptionalAction, + default=True, + help="Extract archives after download (default: true).", + ) + + args = parser.parse_args() + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + resolved_downloads = _resolve_zenodo_record( + args.url, + filename_override=args.filename, + record_pattern=args.record_pattern, + record_max_files=args.record_max_files, + ) + print(f"Selected {len(resolved_downloads)} file(s) from {args.url}") + + for resolved_url, filename in resolved_downloads: + dest = output_dir / filename + print(f"Downloading {resolved_url} -> {dest}") + urlretrieve(resolved_url, dest) + + if args.extract: + if _is_archive(dest): + extract_dir = output_dir / dest.stem + extract_dir.mkdir(parents=True, exist_ok=True) + _extract_archive(dest, extract_dir) + print(f"Extracted to {extract_dir}") + else: + print(f"No extraction performed for {dest.name} (not an archive).") + + +if __name__ == "__main__": + main() diff --git a/resources/data/msnlib/preprocess_msnlib.py b/resources/data/msnlib/preprocess_msnlib.py new file mode 100644 index 0000000..7eeebcc --- /dev/null +++ b/resources/data/msnlib/preprocess_msnlib.py @@ -0,0 +1,497 @@ +#!/usr/bin/env python3 +"""Preprocess MSnLib spectra with full parity to lib_loader/msnlib_loader.ipynb.""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +import numpy as np +import pandas as pd +from rdkit import Chem, RDLogger +from rdkit.Chem import Descriptors +from sklearn.model_selection import train_test_split + +from fiora.IO import mgfReader +from fiora.IO.LibraryLoader import LibraryLoader +from fiora.MOL import constants as mol_constants +from fiora.MOL.Metabolite import Metabolite +from fiora.MOL.MetaboliteIndex import MetaboliteIndex +from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder + +RDLogger.DisableLog("rdApp.*") + +BASE_DIR = Path(__file__).resolve().parent +DEFAULT_OUTPUT = BASE_DIR / "library.csv" +DEFAULT_ALLOWED_PRECURSOR_MODES = ["[M+H]+", "[M-H]-", "[M]+", "[M]-"] +DEFAULT_SPECTYPES = ["SINGLE_BEST_SCAN", "SAME_ENERGY", "SINGLE_SCAN"] + + +def _log(msg: str, verbose: bool) -> None: + if verbose: + print(f"[preprocess_msnlib] {msg}", flush=True) + + +def _iter_progress(iterable, *, total: int, desc: str, enabled: bool): + if not enabled: + return iterable + try: + from tqdm.auto import tqdm # Optional dependency + except Exception: + return iterable + return tqdm(iterable, total=total, desc=desc) + + +def _load_msnlib_dir(path: Path) -> pd.DataFrame: + dfs = [] + for filename in sorted(path.iterdir()): + if not 
filename.name.endswith("ms2.mgf"): + continue + df = pd.DataFrame(mgfReader.read(str(filename))) + df["file"] = filename.name + df["lib"] = "MSnLib" + parts = filename.name.split("_") + df["origin"] = parts[1] if len(parts) > 1 else "" + dfs.append(df) + if not dfs: + raise SystemExit(f"No ms2.mgf files found in {path}") + df = pd.concat(dfs, ignore_index=True) + df.reset_index(inplace=True) + return df + + +def _compute_ce_steps(series: pd.Series, delim: str) -> pd.Series: + def _parse(val): + if pd.isna(val): + return [] + text = str(val).strip() + if "[" in text and "]" in text: + text = text.strip("[]") + parts = [p for p in text.split(delim) if p] + vals = [] + for p in parts: + try: + vals.append(float(p)) + except ValueError: + continue + if vals: + return vals + # fallback single float + try: + return [float(text)] + except ValueError: + return [] + + return series.apply(_parse) + + +def _reweight_groups(df: pd.DataFrame) -> pd.DataFrame: + df["num_per_group"] = df["group_id"].map(df["group_id"].value_counts()) + df["loss_weight"] = 1.0 / df["num_per_group"] + return df + + +def _apply_hard_soft_filters(df: pd.DataFrame) -> pd.DataFrame: + hard_filters = {"min_peaks": 2, "min_coverage": 0.5, "max_precursor_intensity": 0.9} + soft_filters = { + "desired_peaks": 4, + "desired_coverage": 0.75, + "desired_peak_percentage": 0.5, + } + drop_indices = [] + for i, data in df.iterrows(): + m = data["Metabolite"] + hard_pass = True + if m.match_stats["num_peak_matches_filtered"] < hard_filters["min_peaks"]: + hard_pass = False + if m.match_stats["coverage"] < hard_filters["min_coverage"]: + hard_pass = False + if m.match_stats["precursor_prob"] > hard_filters["max_precursor_intensity"]: + hard_pass = False + if not hard_pass: + drop_indices.append(i) + continue + + soft_pass = False + if m.match_stats["num_peak_matches_filtered"] >= soft_filters["desired_peaks"]: + soft_pass = True + if ( + m.match_stats["percent_peak_matches_filtered"] + >= soft_filters["desired_peak_percentage"] + ): + soft_pass = True + if m.match_stats["coverage"] >= soft_filters["desired_coverage"]: + soft_pass = True + if not soft_pass: + drop_indices.append(i) + + if drop_indices: + df = df.drop(drop_indices) + return df + + +def _assign_reference_splits( + df: pd.DataFrame, + reference_path: str, + casmi16_path: str | None, + casmi22_path: str | None, + casmi16t_path: str | None, + seed: int, +) -> pd.DataFrame: + L = LibraryLoader() + df_merged = L.load_from_csv(reference_path) + other_dfs = { + "train": df_merged[df_merged["dataset"] == "training"].drop_duplicates( + subset=["group_id"] + ), + "val": df_merged[df_merged["dataset"] == "validation"].drop_duplicates( + subset=["group_id"] + ), + "test": df_merged[df_merged["dataset"] == "test"].drop_duplicates( + subset=["group_id"] + ), + } + if casmi16_path: + other_dfs["test"] = pd.concat( + [other_dfs["test"], pd.read_csv(casmi16_path, index_col=[0])] + ) + if casmi16t_path: + other_dfs["test"] = pd.concat( + [other_dfs["test"], pd.read_csv(casmi16t_path, index_col=[0])] + ) + if casmi22_path: + other_dfs["test"] = pd.concat( + [other_dfs["test"], pd.read_csv(casmi22_path, index_col=[0])] + ) + other_dfs["test"] = other_dfs["test"].drop_duplicates(subset=["SMILES"]) + + lookup_table = {"train": set(), "val": set(), "test": set()} + for key, df_x in other_dfs.items(): + df_x["Metabolite"] = df_x["SMILES"].apply(Metabolite) + for _, data in df_x.iterrows(): + m = data["Metabolite"] + lookup_table[key].add((m.ExactMolWeight, m.morganFingerCountOnes)) + + train, val, 
test = [], [], [] + for gid in df["group_id"].unique(): + m = df[df["group_id"] == gid].iloc[0]["Metabolite"] + fast_id = (m.ExactMolWeight, m.morganFingerCountOnes) + found_match = False + if fast_id in lookup_table["train"]: + for _, data in other_dfs["train"].iterrows(): + if m == data["Metabolite"]: + train.append(gid) + found_match = True + break + if not found_match and fast_id in lookup_table["val"]: + for _, data in other_dfs["val"].iterrows(): + if m == data["Metabolite"]: + val.append(gid) + found_match = True + break + if not found_match and fast_id in lookup_table["test"]: + for _, data in other_dfs["test"].iterrows(): + if m == data["Metabolite"]: + test.append(gid) + break + + keys = np.unique(df["group_id"].astype(int)) + mask = ~np.isin(keys, train + val + test) + unassigned_keys = keys[mask] + desired_split_size = int(len(keys) * 0.1) + test_size_remaining = desired_split_size - len(test) + val_size_remaining = desired_split_size - len(val) + + test_new_frac = test_size_remaining / len(unassigned_keys) if len(unassigned_keys) else 0 + val_new_frac = val_size_remaining / len(unassigned_keys) if len(unassigned_keys) else 0 + if len(unassigned_keys): + temp_keys, test_keys = train_test_split( + unassigned_keys, test_size=test_new_frac, random_state=seed + ) + adjusted_val_size = val_new_frac / (1 - test_new_frac) if (1 - test_new_frac) else 0 + train_keys, val_keys = train_test_split( + temp_keys, test_size=adjusted_val_size, random_state=seed + ) + train = np.concatenate((np.array(train), train_keys)) + val = np.concatenate((np.array(val), val_keys)) + test = np.concatenate((np.array(test), test_keys)) + + df["dataset"] = df["group_id"].apply( + lambda x: "training" if x in train else "validation" if x in val else "test" + ) + df["datasplit"] = df["dataset"] + return df + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Preprocess MSnLib spectra (full parity with msnlib_loader.ipynb)." 
+ ) + parser.add_argument( + "--msnlib-dir", + default=str(BASE_DIR / "raw"), + help="Directory with MSnLib ms2.mgf files (default: ./raw).", + ) + parser.add_argument("--version", default="v7", help="MSnLib version (v5/v7).") + parser.add_argument( + "--filter-spectype", + action=argparse.BooleanOptionalAction, + default=True, + help="Filter spectra by SPECTYPE (default: true).", + ) + parser.add_argument( + "--allowed-spectypes", + default=",".join(DEFAULT_SPECTYPES), + help="Comma-separated list of spectypes to keep.", + ) + parser.add_argument("--ppm-num", type=int, default=10) + parser.add_argument("--ce-upper-limit", type=float, default=100.0) + parser.add_argument("--weight-upper-limit", type=float, default=1000.0) + parser.add_argument( + "--allowed-precursor-modes", + default=",".join(DEFAULT_ALLOWED_PRECURSOR_MODES), + help="Comma-separated precursor modes to keep.", + ) + parser.add_argument( + "--reference-splits", + default=None, + help="Path to reference datasplits CSV (e.g., datasplits_Jan24.csv).", + ) + parser.add_argument("--casmi16", default=None, help="Path to CASMI-16 CSV.") + parser.add_argument("--casmi22", default=None, help="Path to CASMI-22 CSV.") + parser.add_argument("--casmi16t", default=None, help="Path to CASMI-16T CSV.") + parser.add_argument( + "--assign-datasplit", + action=argparse.BooleanOptionalAction, + default=True, + help="Assign train/val/test splits if no reference provided (default: true).", + ) + parser.add_argument("--train-frac", type=float, default=0.8) + parser.add_argument("--val-frac", type=float, default=0.1) + parser.add_argument("--test-frac", type=float, default=0.1) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--verbose", + action=argparse.BooleanOptionalAction, + default=True, + help="Print stage-level progress messages (default: true).", + ) + parser.add_argument( + "--progress", + action=argparse.BooleanOptionalAction, + default=True, + help="Show tqdm progress bars for heavy loops (default: true).", + ) + parser.add_argument( + "--output", + default=str(DEFAULT_OUTPUT), + help="Output CSV path (default: ./library.csv).", + ) + + args = parser.parse_args() + msnlib_dir = Path(args.msnlib_dir) + + _log(f"Loading MSnLib files from {msnlib_dir}", args.verbose) + df = _load_msnlib_dir(msnlib_dir) + _log(f"Loaded {len(df)} raw spectra rows.", args.verbose) + delim = ", " if args.version == "v5" else "," + df["CE_steps"] = _compute_ce_steps(df["COLLISION_ENERGY"], delim) + df["Num_steps"] = df["CE_steps"].apply(len) + df["CE"] = df["CE_steps"].apply(lambda x: sum(x) / len(x) if x else np.nan) + + if args.filter_spectype: + allowed_spectypes = [s.strip() for s in args.allowed_spectypes.split(",") if s.strip()] + before = len(df) + df = df[df["SPECTYPE"].isin(allowed_spectypes)] + _log( + f"SPECTYPE filter kept {len(df)}/{before} rows: {allowed_spectypes}", + args.verbose, + ) + + df["peaks"] = df["peaks"].apply(lambda p: p if isinstance(p, dict) else None) + before = len(df) + df = df[df["peaks"].notna()].copy() + _log(f"Rows with valid peaks: {len(df)}/{before}", args.verbose) + + tolerance = args.ppm_num * mol_constants.PPM + df["PPM_num"] = args.ppm_num + df["ppm_peak_tolerance"] = tolerance + + _log("Constructing Metabolite objects...", args.verbose) + df["Metabolite"] = [ + Metabolite(smiles) + for smiles in _iter_progress( + df["SMILES"], total=len(df), desc="Metabolites", enabled=args.progress + ) + ] + _log("Building molecular structure graphs...", args.verbose) + for m in _iter_progress( + 
df["Metabolite"], total=len(df), desc="Build graphs", enabled=args.progress + ): + m.create_molecular_structure_graph() + _log("Computing graph attributes...", args.verbose) + for m in _iter_progress( + df["Metabolite"], total=len(df), desc="Graph attrs", enabled=args.progress + ): + m.compute_graph_attributes(memory_safe=False) + + mindex = MetaboliteIndex() + _log("Indexing metabolites and creating fragmentation trees...", args.verbose) + mindex.index_metabolites(df["Metabolite"]) + h_plus = Chem.MolFromSmiles("[H+]") + mol_constants.ADDUCT_WEIGHTS.update( + { + "[M+2H]-": Descriptors.ExactMolWt(h_plus) + + 1 * Descriptors.ExactMolWt(Chem.MolFromSmiles("[H]")), + "[M+3H]-": Descriptors.ExactMolWt(h_plus) + + 2 * Descriptors.ExactMolWt(Chem.MolFromSmiles("[H]")), + } + ) + mindex.create_fragmentation_trees() + mindex.add_fragmentation_trees_to_metabolite_list( + df["Metabolite"], graph_mismatch_policy="recompute" + ) + + df["group_id"] = df["Metabolite"].apply(lambda x: x.get_id()) + df = _reweight_groups(df) + + _log("Matching fragments to peaks...", args.verbose) + for metabolite, peaks in _iter_progress( + zip(df["Metabolite"], df["peaks"]), + total=len(df), + desc="Match fragments", + enabled=args.progress, + ): + metabolite.match_fragments_to_peaks( + peaks["mz"], + peaks["intensity"], + tolerance=tolerance, + match_stats_only=True, + ) + + df["PEPMASS"] = pd.to_numeric(df["PEPMASS"], errors="coerce") + df["RTINSECONDS"] = pd.to_numeric(df["RTINSECONDS"], errors="coerce") + df["ionization"] = "ESI" + df["instrument"] = "HCD" + df["Precursor_type"] = df["ADDUCT"] + + metadata_key_map = { + "name": "NAME", + "collision_energy": "CE", + "instrument": "instrument", + "ionization": "ionization", + "precursor_mz": "PEPMASS", + "precursor_mode": "Precursor_type", + "retention_time": "RTINSECONDS", + "ce_steps": "CE_steps", + } + + setup_encoder = CovariateFeatureEncoder( + feature_list=[ + "collision_energy", + "molecular_weight", + "precursor_mode", + "instrument", + ] + ) + rt_encoder = CovariateFeatureEncoder( + feature_list=["molecular_weight", "precursor_mode", "instrument"] + ) + setup_encoder.normalize_features["collision_energy"]["max"] = args.ce_upper_limit + setup_encoder.normalize_features["molecular_weight"]["max"] = args.weight_upper_limit + rt_encoder.normalize_features["molecular_weight"]["max"] = args.weight_upper_limit + + df["summary"] = df.apply( + lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1 + ) + df.apply( + lambda x: x["Metabolite"].add_metadata(x["summary"], setup_encoder, rt_encoder), + axis=1, + ) + + allowed_precursors = [ + x.strip() for x in args.allowed_precursor_modes.split(",") if x.strip() + ] + before = len(df) + df = df[df["ADDUCT"].isin(allowed_precursors)] + _log( + f"Precursor mode filter kept {len(df)}/{before} rows: {allowed_precursors}", + args.verbose, + ) + + correct_energy = df["Metabolite"].apply( + lambda x: (x.metadata["collision_energy"] <= args.ce_upper_limit) + and (x.metadata["collision_energy"] > 1) + ) + before = len(df) + df = df[correct_energy] + _log(f"Collision energy filter kept {len(df)}/{before} rows.", args.verbose) + + correct_weight = df["Metabolite"].apply( + lambda x: x.metadata["molecular_weight"] <= args.weight_upper_limit + ) + before = len(df) + df = df[correct_weight] + _log(f"Molecular weight filter kept {len(df)}/{before} rows.", args.verbose) + + before = len(df) + df = _apply_hard_soft_filters(df) + _log(f"Peak-match quality filters kept {len(df)}/{before} rows.", args.verbose) + + if 
args.reference_splits: + _log("Assigning datasplits from reference files...", args.verbose) + df = _assign_reference_splits( + df, + args.reference_splits, + args.casmi16, + args.casmi22, + args.casmi16t, + args.seed, + ) + elif args.assign_datasplit: + _log("Assigning random datasplits...", args.verbose) + group_ids = df["group_id"].unique().tolist() + rng = np.random.default_rng(args.seed) + rng.shuffle(group_ids) + n = len(group_ids) + n_train = int(n * args.train_frac) + n_val = int(n * args.val_frac) + n_test = int(n * args.test_frac) + if n_train + n_val + n_test > n: + n_test = max(0, n - (n_train + n_val)) + train_ids = set(group_ids[:n_train]) + val_ids = set(group_ids[n_train : n_train + n_val]) + test_ids = set(group_ids[n_train + n_val : n_train + n_val + n_test]) + + def _split_label(gid): + if gid in train_ids: + return "training" + if gid in val_ids: + return "validation" + if gid in test_ids: + return "test" + return "training" + + df["datasplit"] = df["group_id"].apply(_split_label) + + if "datasplit" in df.columns: + counts = df["datasplit"].value_counts().to_dict() + _log(f"Split counts: {counts}", args.verbose) + + df = _reweight_groups(df) + + if "Metabolite" in df.columns: + df = df.drop(columns=["Metabolite"]) + + for col in ["peaks", "summary"]: + if col in df.columns: + df[col] = df[col].apply(lambda v: json.dumps(v) if isinstance(v, dict) else v) + + _log(f"Writing output to {args.output}", args.verbose) + df.to_csv(args.output, index=False) + print(f"Wrote {len(df)} rows to {args.output}") + + +if __name__ == "__main__": + main() From ae08dad20d46210cb0a6c4b89d862189915c173b Mon Sep 17 00:00:00 2001 From: Philipp Benner Date: Fri, 6 Mar 2026 10:12:02 +0100 Subject: [PATCH 04/15] eval script and kl bugfix --- README.md | 14 + fiora/GNN/SpectralTrainer.py | 51 ++- fiora/GNN/Trainer.py | 40 ++- fiora/IO/LibraryLoader.py | 6 +- fiora/MS/SimulationFramework.py | 13 +- fiora/cli/eval.py | 372 +++++++++++++++++++++ fiora/cli/train.py | 2 +- pyproject.toml | 3 +- resources/data/msnlib/download_msnlib.py | 4 +- resources/data/msnlib/preprocess_msnlib.py | 30 +- scripts/fiora-eval | 6 + 11 files changed, 500 insertions(+), 41 deletions(-) create mode 100644 fiora/cli/eval.py create mode 100644 scripts/fiora-eval diff --git a/README.md b/README.md index 1647164..067ea3d 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,20 @@ fiora-train \ --precursor-modes "[M+H]+,[M-H]-,[M]+,[M]-" ``` +### Model Evaluation CLI + +You can evaluate a trained checkpoint on validation/test splits with: + +```bash +fiora-eval \ + -i resources/data/msnlib/library.csv \ + -m checkpoints/fiora.pt \ + --device cuda:0 \ + --output-dir checkpoints/eval +``` + +This prints split-level summary scores (default: `spectral_sqrt_cosine`) and writes per-split result files like `validation_eval.csv` and `test_eval.csv` when `--output-dir` is set. + ## The Algorithm FIORA has been developed as a computational tool to predict bond cleavages that occur in the MS/MS fragmentation process and estimate the probabilities of resulting fragment ions. To that end, FIORA utilizes graph neural networks to learn local molecular neighborhoods around bonds, combined with edge prediction to simulate bond dissociation. The prediction determines which fragment (left or right of the bond cleavage, with up to four possible hydrogen losses) retains the charge and which becomes the neutral loss. The figure below illustrates an example fragmentation prediction for a single bond. 
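As a companion to the `fiora-eval` summary above: the reported `spectral_sqrt_cosine` is the same square-root-transformed spectral cosine that `tests/test_fiora_predict.py` uses to verify output integrity. A minimal sketch of computing it directly, assuming the `{"mz": [...], "intensity": [...]}` peak layout produced by `mgfReader` (the peak values below are made up for illustration):

```python
import numpy as np

from fiora.MS.spectral_scores import spectral_cosine

# Two peak dictionaries in the loaders' {"mz": [...], "intensity": [...]} layout;
# the m/z and intensity values are illustrative only.
reference = {"mz": [81.07, 109.10, 137.13], "intensity": [0.20, 0.45, 1.00]}
predicted = {"mz": [81.07, 109.10, 137.13], "intensity": [0.25, 0.40, 1.00]}

# transform=np.sqrt compresses dominant peaks before taking the cosine,
# mirroring the spectral_sqrt_cosine score that fiora-eval summarizes.
score = spectral_cosine(reference, predicted, transform=np.sqrt)
print(f"sqrt spectral cosine: {score:.4f}")
```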
diff --git a/fiora/GNN/SpectralTrainer.py b/fiora/GNN/SpectralTrainer.py index 4250d61..08f2f29 100644 --- a/fiora/GNN/SpectralTrainer.py +++ b/fiora/GNN/SpectralTrainer.py @@ -25,8 +25,8 @@ def __init__( train_val_split: float = 0.8, split_by_group: bool = False, only_training: bool = False, - train_keys: List[int] = [], - val_keys: List[int] = [], + train_keys: List[int] | None = None, + val_keys: List[int] | None = None, y_tag: str = "y", metric_dict: Dict = None, problem_type: Literal[ @@ -86,6 +86,8 @@ def _build_progress_iterator(dataloader, enabled=False, desc=""): @staticmethod def _format_metric(stats): + if "kl" in stats: + return "kl", float(stats["kl"].detach().cpu().item()) if "mse" in stats: rmse = torch.sqrt(stats["mse"]) return "rmse", float(rmse.detach().cpu().item()) @@ -210,10 +212,22 @@ def _validation_loop( model.eval() y_pred = model(batch, with_RT=with_RT, with_CCS=with_CCS) if mask_name: + kwargs = {} + if with_weights: + kwargs = {"weight": batch["weight_tensor"][batch[mask_name]]} metrics.update( y_pred["fragment_probs"][batch[mask_name]], batch[self.y_tag][batch[mask_name]], + **kwargs, ) + if not rt_metric and torch.any(batch[mask_name]): + batch_loss = loss_fn( + y_pred["fragment_probs"][batch[mask_name]], + batch[self.y_tag][batch[mask_name]], + **kwargs, + ) + validation_loss += self._to_float(batch_loss) + num_batches += 1 else: kwargs = {} if with_weights: @@ -287,14 +301,14 @@ def train( self.validation_data, batch_size=batch_size, num_workers=self.num_workers, - shuffle=True, + shuffle=False, ) - using_weighted_loss_func = isinstance(loss_fn, WeightedMSELoss) | isinstance( - loss_fn, WeightedMAELoss + using_weighted_loss_func = isinstance( + loss_fn, (WeightedMSELoss, WeightedMAELoss) ) show_train_progress = len(self.training_data) > TQDM_DATA_THRESHOLD - show_val_progress = ( - (not self.only_training) and (len(self.validation_data) > TQDM_DATA_THRESHOLD) + show_val_progress = (not self.only_training) and ( + len(self.validation_data) > TQDM_DATA_THRESHOLD ) # Main loop @@ -334,12 +348,13 @@ def train( show_progress=show_val_progress, progress_desc=f"Val {e + 1}/{epochs}", ) + val_metric_name, val_metric_value = self._format_metric(val_stats) else: val_stats, val_loss = None, float("nan") + val_metric_name, val_metric_value = None, None train_metric_name, train_metric_value = self._format_metric(train_stats) if val_stats is not None: - val_metric_name, val_metric_value = self._format_metric(val_stats) val_metric_str = f"val_{val_metric_name}: {val_metric_value:.4f}" else: val_metric_str = "val_metric: n/a" @@ -356,7 +371,7 @@ def train( if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): last_lr = scheduler.get_last_lr()[0] if is_val_cycle: - scheduler.step(torch.sqrt(val_stats["mse"])) + scheduler.step(val_metric_value) if scheduler.get_last_lr()[0] < last_lr: print( f"\t >> Learning rate reduced from {last_lr:1.0e} to {scheduler.get_last_lr()[0]:1.0e}" @@ -367,13 +382,19 @@ def train( # Save history if is_val_cycle: # Update checkpoint - if val_stats["mse"].tolist() < self.checkpoint_stats["val_loss"]: + if val_metric_value < self.checkpoint_stats["val_loss"]: + checkpoint_data = { + "epoch": e + 1, + "val_loss": val_metric_value, + "val_metric_name": val_metric_name, + "sqrt_val_loss": val_metric_value, + } + if "mse" in val_stats: + checkpoint_data["sqrt_val_loss"] = self._to_float( + torch.sqrt(val_stats["mse"]) + ) self._update_checkpoint( - { - "epoch": e + 1, - "val_loss": val_stats["mse"].tolist(), - "sqrt_val_loss": 
torch.sqrt(val_stats["mse"]).tolist(), - }, + checkpoint_data, model, ) print(f"\t >> Set new checkpoint to epoch {e + 1}") diff --git a/fiora/GNN/Trainer.py b/fiora/GNN/Trainer.py index 17f5943..6666903 100644 --- a/fiora/GNN/Trainer.py +++ b/fiora/GNN/Trainer.py @@ -22,8 +22,8 @@ def __init__( train_val_split: float = 0.8, split_by_group: bool = False, only_training: bool = False, - train_keys: List[int] = [], - val_keys: List[int] = [], + train_keys: List[int] | None = None, + val_keys: List[int] | None = None, seed: int = 42, num_workers: int = 0, device: str = "cpu", @@ -50,10 +50,12 @@ def _split_by_group( self, data, train_val_split: float, - train_keys: List[int], - val_keys: List[int], + train_keys: List[int] | None, + val_keys: List[int] | None, seed: int, ): + train_keys = train_keys or [] + val_keys = val_keys or [] group_ids = [getattr(x, "group_id") for x in data] keys = np.unique(group_ids) if len(train_keys) > 0 and len(val_keys) > 0: @@ -94,8 +96,8 @@ def _get_default_metrics( def _init_checkpoint_system(self, save_path: str) -> None: self.checkpoint_stats = { "epoch": -1, - "val_loss": 100000.0, - "sqrt_val_loss": 100000.0, + "val_loss": float("inf"), + "sqrt_val_loss": float("inf"), "file": save_path, } @@ -115,12 +117,30 @@ def _init_history(self) -> None: "lr": [], } + @staticmethod + def _to_float(value): + if isinstance(value, torch.Tensor): + return float(value.detach().cpu().item()) + return float(value) + + def _extract_primary_error(self, stats): + if "mse" in stats: + mse = self._to_float(stats["mse"]) + return mse, float(np.sqrt(mse)) + if "mae" in stats: + mae = self._to_float(stats["mae"]) + return mae, float("nan") + key = next(iter(stats.keys())) + return self._to_float(stats[key]), float("nan") + def _update_history(self, epoch, train_stats, val_stats, lr) -> None: + train_error, train_sqrt_error = self._extract_primary_error(train_stats) + val_error, val_sqrt_error = self._extract_primary_error(val_stats) self.history["epoch"].append(epoch) - self.history["train_error"].append(train_stats["mse"]) - self.history["sqrt_train_error"].append(torch.sqrt(train_stats["mse"]).tolist()) - self.history["val_error"].append(val_stats["mse"]) - self.history["sqrt_val_error"].append(torch.sqrt(val_stats["mse"]).tolist()) + self.history["train_error"].append(train_error) + self.history["sqrt_train_error"].append(train_sqrt_error) + self.history["val_error"].append(val_error) + self.history["sqrt_val_error"].append(val_sqrt_error) self.history["lr"].append(lr) def is_group_in_training_set(self, group_id): diff --git a/fiora/IO/LibraryLoader.py b/fiora/IO/LibraryLoader.py index 7b863a6..2c74f04 100644 --- a/fiora/IO/LibraryLoader.py +++ b/fiora/IO/LibraryLoader.py @@ -9,9 +9,7 @@ def load_from_csv(self, path): return pd.read_csv(path, index_col=[0], low_memory=False) def load_from_msp(self): - # TODO IMPLEMENT - return + raise NotImplementedError("MSP loading is not implemented yet.") def clean_library(self): - # TODO IMPLEMENT + parameters for filtration - return + raise NotImplementedError("Library cleaning is not implemented yet.") diff --git a/fiora/MS/SimulationFramework.py b/fiora/MS/SimulationFramework.py index 41c5116..59b83d2 100644 --- a/fiora/MS/SimulationFramework.py +++ b/fiora/MS/SimulationFramework.py @@ -284,12 +284,23 @@ def simulate_all( suffix: str = "", groundtruth=True, min_intensity: float = 0.001, + progress: bool = False, + progress_desc: str = "Evaluate", ): with torch.no_grad(): model.eval() - for i, data in df.iterrows(): + iterator = 
df.iterrows() + if progress: + try: + from tqdm.auto import tqdm + + iterator = tqdm(iterator, total=len(df), desc=progress_desc) + except Exception: + pass + + for i, data in iterator: metabolite = data["Metabolite"] stats = self.simulate_and_score( metabolite, diff --git a/fiora/cli/eval.py b/fiora/cli/eval.py new file mode 100644 index 0000000..aaa3e95 --- /dev/null +++ b/fiora/cli/eval.py @@ -0,0 +1,372 @@ +#! /usr/bin/env python +import argparse +import ast +import json +import os +import re +import warnings +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from rdkit import RDLogger + +from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder +from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder +from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder +from fiora.GNN.FioraModel import FioraModel +from fiora.IO.LibraryLoader import LibraryLoader +from fiora.MOL.Metabolite import Metabolite +from fiora.MOL.MetaboliteIndex import MetaboliteIndex +from fiora.MS.SimulationFramework import SimulationFramework + +RDLogger.DisableLog("rdApp.*") +warnings.filterwarnings("ignore", category=SyntaxWarning) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + prog="fiora-eval", + description="Evaluate a trained FIORA model on validation/test splits.", + ) + parser.add_argument( + "-i", + "--input", + required=True, + help="Path to preprocessed CSV containing spectra/metadata/SMILES.", + ) + parser.add_argument( + "-m", + "--model", + required=True, + help="Path to checkpoint .pt produced by fiora-train.", + ) + parser.add_argument( + "--device", + default="auto", + help="Device to run on (e.g. cpu, cuda:0). Default: auto.", + ) + parser.add_argument( + "--datasplit-col", + default="datasplit", + help="Column containing split labels (default: datasplit).", + ) + parser.add_argument( + "--splits", + default="validation,test", + help="Comma-separated splits to evaluate (default: validation,test).", + ) + parser.add_argument( + "--score", + default="spectral_sqrt_cosine", + help="Score column to summarize after evaluation.", + ) + parser.add_argument( + "--y-label", + default="compiled_probsALL", + help="Prediction target label used during training.", + ) + parser.add_argument( + "--min-prob", + type=float, + default=0.001, + help="Minimum predicted peak intensity to keep.", + ) + parser.add_argument( + "--fragmentation-depth", + type=int, + default=1, + help="Fragmentation depth for metabolite trees.", + ) + parser.add_argument( + "--graph-mismatch-policy", + choices=["recompute", "ignore"], + default="recompute", + ) + parser.add_argument("--summary-col", default="summary") + parser.add_argument("--peaks-col", default="peaks") + parser.add_argument("--smiles-col", default="SMILES") + parser.add_argument("--group-id-col", default="group_id") + parser.add_argument("--max-rows", type=int, default=None) + parser.add_argument( + "--output-dir", + default=None, + help="Optional directory to write evaluated split CSV files.", + ) + parser.add_argument( + "--progress", + action=argparse.BooleanOptionalAction, + default=True, + help="Show tqdm progress bars (default: true).", + ) + parser.add_argument( + "--index-col", + type=int, + default=0, + help="CSV index column (default: 0). 
Use --no-index-col to disable.", + ) + parser.add_argument( + "--no-index-col", + action="store_true", + help="Disable index_col when reading CSV.", + ) + return parser.parse_args() + + +def _resolve_device(device: str) -> str: + if device == "auto": + return "cuda:0" if torch.cuda.is_available() else "cpu" + return device + + +def _parse_dict(val): + if isinstance(val, dict): + return val + if val is None or (isinstance(val, float) and np.isnan(val)): + return None + text = str(val).strip() + if not text: + return None + try: + return json.loads(text) + except Exception: + pass + norm = re.sub(r"\b(?:NaN|nan)\b", "None", text) + norm = re.sub(r"\b(?:Infinity|inf)\b", "1e309", norm) + norm = re.sub(r"\b(?:-Infinity|-inf)\b", "-1e309", norm) + try: + parsed = ast.literal_eval(norm) + return parsed if isinstance(parsed, dict) else None + except Exception: + return None + + +def _parse_dict_columns(df: pd.DataFrame, columns: list[str]) -> pd.DataFrame: + for col in columns: + if col in df.columns: + df[col] = df[col].apply(_parse_dict) + return df + + +def _safe_metabolite(smiles: str): + try: + return Metabolite(smiles) + except Exception: + return None + + +def _build_summary_from_columns(row): + metadata_key_map = { + "name": ["Name", "NAME", "Title", "TITLE"], + "collision_energy": ["CE", "COLLISION_ENERGY", "CollisionEnergy"], + "instrument": ["Instrument_type", "instrument", "INSTRUMENT_TYPE"], + "precursor_mode": ["Precursor_type", "ADDUCT", "PRECURSORTYPE"], + "precursor_mz": ["PrecursorMZ", "PEPMASS", "PRECURSORMZ"], + "retention_time": ["RETENTIONTIME", "RTINSECONDS", "retention_time"], + "ccs": ["CCS", "ccs"], + } + summary = {} + for key, cols in metadata_key_map.items(): + for col in cols: + if col in row.index: + value = row[col] + if value is not None and not ( + isinstance(value, float) and np.isnan(value) + ): + summary[key] = value + break + return summary + + +def _prepare_metabolites( + df: pd.DataFrame, model, progress: bool = True +) -> tuple[pd.DataFrame, int]: + setup_features = model.model_params.get( + "setup_features", + [ + "collision_energy", + "molecular_weight", + "precursor_mode", + "instrument", + "element_composition", + ], + ) + rt_features = model.model_params.get( + "rt_features", + ["molecular_weight", "precursor_mode", "instrument", "element_composition"], + ) + setup_sets = model.model_params.get("setup_features_categorical_set") + + node_encoder = AtomFeatureEncoder( + feature_list=["symbol", "num_hydrogen", "ring_type"] + ) + bond_encoder = BondFeatureEncoder(feature_list=["bond_type", "ring_type"]) + setup_encoder = CovariateFeatureEncoder( + feature_list=setup_features, sets_overwrite=setup_sets + ) + rt_encoder = CovariateFeatureEncoder( + feature_list=rt_features, sets_overwrite=setup_sets + ) + + invalid_rows = [] + iterator = df.iterrows() + if progress: + try: + from tqdm.auto import tqdm + + iterator = tqdm(iterator, total=len(df), desc="Prepare metabolites") + except Exception: + pass + + for idx, row in iterator: + smiles = row.get("SMILES") + if smiles is None or (isinstance(smiles, float) and np.isnan(smiles)): + invalid_rows.append(idx) + continue + mol = _safe_metabolite(smiles) + if mol is None: + invalid_rows.append(idx) + continue + + mol.create_molecular_structure_graph() + mol.compute_graph_attributes(node_encoder, bond_encoder) + if "group_id" in df.columns: + try: + mol.set_id(int(row["group_id"])) + except Exception: + pass + + summary = row.get("summary") + if summary is None: + summary = _build_summary_from_columns(row) + + 
try: + mol.add_metadata(summary, setup_encoder, rt_encoder) + except Exception: + invalid_rows.append(idx) + continue + df.at[idx, "Metabolite"] = mol + + if invalid_rows: + df = df.drop(index=invalid_rows).copy() + return df, len(invalid_rows) + + +def _load_model(path: str, dev: str): + state_path = path.replace(".pt", "_state.pt") + params_path = path.replace(".pt", "_params.json") + if os.path.exists(state_path) and os.path.exists(params_path): + return FioraModel.load_from_state_dict(path).to(dev) + return FioraModel.load(path).to(dev) + + +def _to_csv_safe(df: pd.DataFrame) -> pd.DataFrame: + out = df.copy() + if "Metabolite" in out.columns: + out = out.drop(columns=["Metabolite"]) + for col in out.columns: + if out[col].dtype == "object": + out[col] = out[col].apply( + lambda v: json.dumps(v) if isinstance(v, (dict, list)) else v + ) + return out + + +def main() -> None: + args = parse_args() + dev = _resolve_device(args.device) + np.seterr(invalid="ignore") + + index_col = None if args.no_index_col else args.index_col + loader = LibraryLoader() + df = ( + loader.load_from_csv(args.input) + if index_col == 0 + else pd.read_csv(args.input, index_col=index_col, low_memory=False) + ) + + if args.max_rows: + df = df.iloc[: args.max_rows].copy() + + df = _parse_dict_columns(df, [args.summary_col, args.peaks_col]) + splits = [x.strip() for x in args.splits.split(",") if x.strip()] + if not splits: + raise SystemExit("No valid --splits provided.") + if args.datasplit_col not in df.columns: + raise SystemExit(f"datasplit column '{args.datasplit_col}' not found in input.") + + df = df[df[args.datasplit_col].isin(splits)].copy() + print(f"Loaded {len(df)} rows for splits: {splits}") + if len(df) == 0: + raise SystemExit("No rows left after split filtering.") + + model = _load_model(args.model, dev) + model.eval() + + # Standardize user-configurable column names for downstream code. + if args.summary_col != "summary" and args.summary_col in df.columns: + df["summary"] = df[args.summary_col] + if args.peaks_col != "peaks" and args.peaks_col in df.columns: + df["peaks"] = df[args.peaks_col] + if args.smiles_col != "SMILES" and args.smiles_col in df.columns: + df["SMILES"] = df[args.smiles_col] + if args.group_id_col != "group_id" and args.group_id_col in df.columns: + df["group_id"] = df[args.group_id_col] + + df, dropped = _prepare_metabolites(df, model, progress=args.progress) + if dropped: + print(f"Dropped {dropped} invalid rows during metabolite preparation.") + + mindex = MetaboliteIndex() + mindex.index_metabolites(df["Metabolite"]) + mindex.create_fragmentation_trees(depth=args.fragmentation_depth) + mindex.add_fragmentation_trees_to_metabolite_list( + df["Metabolite"], graph_mismatch_policy=args.graph_mismatch_policy + ) + + fiora = SimulationFramework(None, dev=dev) + use_groundtruth = "peaks" in df.columns + if not use_groundtruth: + print( + "Warning: peaks column not found. Running prediction without score metrics." 
+ ) + + output_dir = None + if args.output_dir: + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + for split in splits: + part = df[df[args.datasplit_col] == split].copy() + if part.empty: + print(f"Split '{split}': 0 rows (skipping).") + continue + part = fiora.simulate_all( + part, + model, + base_attr_name=args.y_label, + groundtruth=use_groundtruth, + min_intensity=args.min_prob, + progress=args.progress, + progress_desc=f"{split} split", + ) + + if args.score in part.columns: + score_vals = pd.to_numeric(part[args.score], errors="coerce") + print( + f"Split '{split}': n={len(part)} | " + f"{args.score}_mean={score_vals.mean():.5f} | " + f"{args.score}_median={score_vals.median():.5f}" + ) + else: + print(f"Split '{split}': n={len(part)} | score '{args.score}' not found.") + + if output_dir is not None: + out_path = output_dir / f"{split}_eval.csv" + _to_csv_safe(part).to_csv(out_path, index=False) + print(f"Wrote {len(part)} rows to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/fiora/cli/train.py b/fiora/cli/train.py index 44ee96e..ff906fc 100644 --- a/fiora/cli/train.py +++ b/fiora/cli/train.py @@ -245,7 +245,7 @@ def _load_model_params(path: str | None) -> dict: def _choose_loss(loss_name: str): if loss_name == "graphwise_kl": - return GraphwiseKLLoss(reduction="mean"), {"mse": GraphwiseKLLossMetric} + return GraphwiseKLLoss(reduction="mean"), {"kl": GraphwiseKLLossMetric} if loss_name == "weighted_mse": return WeightedMSELoss(), {"mse": WeightedMSEMetric} if loss_name == "weighted_mae": diff --git a/pyproject.toml b/pyproject.toml index 06b3a42..a939c87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,10 +34,11 @@ dependencies = [ [project.scripts] fiora-predict = "fiora.cli.predict:main" fiora-train = "fiora.cli.train:main" +fiora-eval = "fiora.cli.eval:main" [tool.setuptools] include-package-data = true -script-files = ["scripts/fiora-predict", "scripts/fiora-train"] # like setup.py `scripts=[...]` (not console entry points) +script-files = ["scripts/fiora-predict", "scripts/fiora-train", "scripts/fiora-eval"] # like setup.py `scripts=[...]` (not console entry points) [tool.setuptools.packages.find] include = ["fiora", "fiora.*"] diff --git a/resources/data/msnlib/download_msnlib.py b/resources/data/msnlib/download_msnlib.py index 632b51c..ac9feba 100644 --- a/resources/data/msnlib/download_msnlib.py +++ b/resources/data/msnlib/download_msnlib.py @@ -95,7 +95,9 @@ def _resolve_zenodo_record( resolved.append((resolved_url, filename)) if not resolved: - raise RuntimeError(f"Could not resolve any download URLs for record {record_id}") + raise RuntimeError( + f"Could not resolve any download URLs for record {record_id}" + ) return resolved diff --git a/resources/data/msnlib/preprocess_msnlib.py b/resources/data/msnlib/preprocess_msnlib.py index 7eeebcc..76d95bb 100644 --- a/resources/data/msnlib/preprocess_msnlib.py +++ b/resources/data/msnlib/preprocess_msnlib.py @@ -203,13 +203,19 @@ def _assign_reference_splits( test_size_remaining = desired_split_size - len(test) val_size_remaining = desired_split_size - len(val) - test_new_frac = test_size_remaining / len(unassigned_keys) if len(unassigned_keys) else 0 - val_new_frac = val_size_remaining / len(unassigned_keys) if len(unassigned_keys) else 0 + test_new_frac = ( + test_size_remaining / len(unassigned_keys) if len(unassigned_keys) else 0 + ) + val_new_frac = ( + val_size_remaining / len(unassigned_keys) if len(unassigned_keys) else 0 + ) if len(unassigned_keys): 
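+        # Carve the test keys out of the unassigned pool first; the validation
+        # fraction below is re-normalized against the remaining (1 - test) share.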
temp_keys, test_keys = train_test_split( unassigned_keys, test_size=test_new_frac, random_state=seed ) - adjusted_val_size = val_new_frac / (1 - test_new_frac) if (1 - test_new_frac) else 0 + adjusted_val_size = ( + val_new_frac / (1 - test_new_frac) if (1 - test_new_frac) else 0 + ) train_keys, val_keys = train_test_split( temp_keys, test_size=adjusted_val_size, random_state=seed ) @@ -301,7 +307,9 @@ def main() -> None: df["CE"] = df["CE_steps"].apply(lambda x: sum(x) / len(x) if x else np.nan) if args.filter_spectype: - allowed_spectypes = [s.strip() for s in args.allowed_spectypes.split(",") if s.strip()] + allowed_spectypes = [ + s.strip() for s in args.allowed_spectypes.split(",") if s.strip() + ] before = len(df) df = df[df["SPECTYPE"].isin(allowed_spectypes)] _log( @@ -399,7 +407,9 @@ def main() -> None: feature_list=["molecular_weight", "precursor_mode", "instrument"] ) setup_encoder.normalize_features["collision_energy"]["max"] = args.ce_upper_limit - setup_encoder.normalize_features["molecular_weight"]["max"] = args.weight_upper_limit + setup_encoder.normalize_features["molecular_weight"]["max"] = ( + args.weight_upper_limit + ) rt_encoder.normalize_features["molecular_weight"]["max"] = args.weight_upper_limit df["summary"] = df.apply( @@ -421,8 +431,10 @@ def main() -> None: ) correct_energy = df["Metabolite"].apply( - lambda x: (x.metadata["collision_energy"] <= args.ce_upper_limit) - and (x.metadata["collision_energy"] > 1) + lambda x: ( + (x.metadata["collision_energy"] <= args.ce_upper_limit) + and (x.metadata["collision_energy"] > 1) + ) ) before = len(df) df = df[correct_energy] @@ -486,7 +498,9 @@ def _split_label(gid): for col in ["peaks", "summary"]: if col in df.columns: - df[col] = df[col].apply(lambda v: json.dumps(v) if isinstance(v, dict) else v) + df[col] = df[col].apply( + lambda v: json.dumps(v) if isinstance(v, dict) else v + ) _log(f"Writing output to {args.output}", args.verbose) df.to_csv(args.output, index=False) diff --git a/scripts/fiora-eval b/scripts/fiora-eval new file mode 100644 index 0000000..93a1f18 --- /dev/null +++ b/scripts/fiora-eval @@ -0,0 +1,6 @@ +#! 
/usr/bin/env python +from fiora.cli.eval import main + + +if __name__ == "__main__": + main() From 5bd61a04407a5e8ca49c4e57a962c2ab1211282b Mon Sep 17 00:00:00 2001 From: Philipp Benner Date: Fri, 6 Mar 2026 10:13:03 +0100 Subject: [PATCH 05/15] test scripts for eval and trainer --- fiora/resources/__init__.py | 0 tests/test_fiora_eval.py | 50 +++++++++++++++++++++++++++++++++++ tests/test_trainer_history.py | 33 +++++++++++++++++++++++ 3 files changed, 83 insertions(+) create mode 100644 fiora/resources/__init__.py create mode 100644 tests/test_fiora_eval.py create mode 100644 tests/test_trainer_history.py diff --git a/fiora/resources/__init__.py b/fiora/resources/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_fiora_eval.py b/tests/test_fiora_eval.py new file mode 100644 index 0000000..b0af877 --- /dev/null +++ b/tests/test_fiora_eval.py @@ -0,0 +1,50 @@ +import io +import os +import sys + +import unittest +from unittest.mock import patch +import contextlib + +from importlib.util import spec_from_loader, module_from_spec +from importlib.machinery import SourceFileLoader + +spec = spec_from_loader( + "fiora-eval", + SourceFileLoader("fiora-eval", os.getcwd() + "/scripts/fiora-eval"), +) +fiora_eval = module_from_spec(spec) +spec.loader.exec_module(fiora_eval) +sys.modules["fiora_eval"] = fiora_eval + + +class TestFioraEval(unittest.TestCase): + def test_missing_args(self): + f = io.StringIO() + with patch("sys.argv", ["main"]): + with self.assertRaises(SystemExit) as cm, contextlib.redirect_stderr(f): + fiora_eval.main() + self.assertEqual(cm.exception.code, 2) + self.assertTrue(f.getvalue().startswith("usage:")) + + def test_help(self): + f = io.StringIO() + with patch("sys.argv", ["main", "-h"]): + with self.assertRaises(SystemExit) as cm, contextlib.redirect_stdout(f): + fiora_eval.main() + self.assertEqual(cm.exception.code, 0) + self.assertTrue(f.getvalue().startswith("usage:")) + self.assertTrue("--model MODEL" in f.getvalue()) + self.assertTrue("--splits SPLITS" in f.getvalue()) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + suite.addTests( + [ + TestFioraEval("test_missing_args"), + TestFioraEval("test_help"), + ] + ) + runner = unittest.TextTestRunner() + runner.run(suite) diff --git a/tests/test_trainer_history.py b/tests/test_trainer_history.py new file mode 100644 index 0000000..e7f38ee --- /dev/null +++ b/tests/test_trainer_history.py @@ -0,0 +1,33 @@ +import math + +import torch + +from fiora.GNN.Trainer import Trainer + + +class _DummyTrainer(Trainer): + def _training_loop(self, model, dataloader, optimizer, loss_fn, **kwargs): + return None + + def _validation_loop(self, model, dataloader, loss_fn, **kwargs): + return None + + def train(self, model, optimizer, loss_fn, **kwargs): + return None + + +def test_update_history_supports_mae_only_stats(): + trainer = _DummyTrainer(data=[], only_training=True) + trainer._init_history() + trainer._update_history( + epoch=1, + train_stats={"mae": torch.tensor(0.5)}, + val_stats={"mae": torch.tensor(0.75)}, + lr=1e-3, + ) + + assert trainer.history["train_error"] == [0.5] + assert trainer.history["val_error"] == [0.75] + assert math.isnan(trainer.history["sqrt_train_error"][0]) + assert math.isnan(trainer.history["sqrt_val_error"][0]) + assert trainer.history["lr"] == [1e-3] From 877cad92128d9c2c7fe419662634eea86ad4cdad Mon Sep 17 00:00:00 2001 From: Philipp Benner Date: Fri, 6 Mar 2026 10:40:58 +0100 Subject: [PATCH 06/15] fiora model info --- README.md | 12 ++++- fiora/cli/eval.py | 
52 +++++++++++++++---- fiora/cli/model_info.py | 93 ++++++++++++++++++++++++++++++++++ pyproject.toml | 3 +- scripts/fiora-model-info | 6 +++ tests/test_fiora_model_info.py | 50 ++++++++++++++++++ 6 files changed, 205 insertions(+), 11 deletions(-) create mode 100644 fiora/cli/model_info.py create mode 100644 scripts/fiora-model-info create mode 100644 tests/test_fiora_model_info.py diff --git a/README.md b/README.md index 067ea3d..f6039ed 100644 --- a/README.md +++ b/README.md @@ -125,7 +125,17 @@ fiora-eval \ --output-dir checkpoints/eval ``` -This prints split-level summary scores (default: `spectral_sqrt_cosine`) and writes per-split result files like `validation_eval.csv` and `test_eval.csv` when `--output-dir` is set. +This prints split-level summary scores (default: `spectral_sqrt_cosine`) and, when available, also reports precursor-excluded metrics (`spectral_sqrt_cosine_wo_prec`, `spectral_sqrt_cosine_avg`). Per-split result files like `validation_eval.csv` and `test_eval.csv` are written when `--output-dir` is set. + +### Model Info CLI + +To inspect key parameters of a trained model checkpoint: + +```bash +fiora-model-info -m checkpoints/fiora.pt +``` + +Use `--as-json` to print the full `model_params` dictionary. ## The Algorithm diff --git a/fiora/cli/eval.py b/fiora/cli/eval.py index aaa3e95..548fe13 100644 --- a/fiora/cli/eval.py +++ b/fiora/cli/eval.py @@ -62,6 +62,12 @@ def parse_args() -> argparse.Namespace: default="spectral_sqrt_cosine", help="Score column to summarize after evaluation.", ) + parser.add_argument( + "--print-wo-prec", + action=argparse.BooleanOptionalAction, + default=True, + help="Also print precursor-excluded score summaries when available (default: true).", + ) parser.add_argument( "--y-label", default="compiled_probsALL", @@ -273,6 +279,13 @@ def _to_csv_safe(df: pd.DataFrame) -> pd.DataFrame: return out +def _metric_stats(part: pd.DataFrame, metric: str) -> tuple[float, float] | None: + if metric not in part.columns: + return None + vals = pd.to_numeric(part[metric], errors="coerce") + return float(vals.mean()), float(vals.median()) + + def main() -> None: args = parse_args() dev = _resolve_device(args.device) @@ -337,6 +350,7 @@ def main() -> None: output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) + summary_table: dict[str, dict[str, tuple[float, float]]] = {} for split in splits: part = df[df[args.datasplit_col] == split].copy() if part.empty: @@ -352,21 +366,41 @@ def main() -> None: progress_desc=f"{split} split", ) - if args.score in part.columns: - score_vals = pd.to_numeric(part[args.score], errors="coerce") - print( - f"Split '{split}': n={len(part)} | " - f"{args.score}_mean={score_vals.mean():.5f} | " - f"{args.score}_median={score_vals.median():.5f}" - ) - else: - print(f"Split '{split}': n={len(part)} | score '{args.score}' not found.") + metrics_to_report = [args.score] + if args.print_wo_prec: + for metric in ["spectral_sqrt_cosine_wo_prec", "spectral_sqrt_cosine_avg"]: + if metric != args.score: + metrics_to_report.append(metric) + + summaries = [] + for metric in metrics_to_report: + stats = _metric_stats(part, metric) + if stats is None: + if metric == args.score: + summaries.append(f"score '{args.score}' not found") + continue + mean, median = stats + summary_table.setdefault(metric, {})[split] = (mean, median) + summaries.append(f"{metric}_mean={mean:.5f} | {metric}_median={median:.5f}") + + print(f"Split '{split}': n={len(part)} | " + " | ".join(summaries)) if output_dir is not None: out_path = 
output_dir / f"{split}_eval.csv" _to_csv_safe(part).to_csv(out_path, index=False) print(f"Wrote {len(part)} rows to {out_path}") + if summary_table: + table = pd.DataFrame( + index=list(summary_table.keys()), columns=splits, dtype=object + ) + for metric, split_stats in summary_table.items(): + for split, (mean, median) in split_stats.items(): + table.at[metric, split] = f"{mean:.5f} / {median:.5f}" + table = table.fillna("-") + print("\nSummary Table (mean / median):") + print(table.to_string()) + if __name__ == "__main__": main() diff --git a/fiora/cli/model_info.py b/fiora/cli/model_info.py new file mode 100644 index 0000000..0a48d64 --- /dev/null +++ b/fiora/cli/model_info.py @@ -0,0 +1,93 @@ +#! /usr/bin/env python +import argparse +import json +import os + +from fiora.GNN.FioraModel import FioraModel + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + prog="fiora-model-info", + description="Load a FIORA .pt checkpoint and print key model parameters.", + ) + parser.add_argument( + "-m", + "--model", + required=True, + help="Path to checkpoint .pt file.", + ) + parser.add_argument( + "--as-json", + action=argparse.BooleanOptionalAction, + default=False, + help="Print the full model_params dictionary as JSON.", + ) + return parser.parse_args() + + +def _load_model(path: str): + state_path = path.replace(".pt", "_state.pt") + params_path = path.replace(".pt", "_params.json") + if os.path.exists(state_path) and os.path.exists(params_path): + return FioraModel.load_from_state_dict(path) + return FioraModel.load(path) + + +def _get(params: dict, key: str, default="-"): + return params.get(key, default) + + +def _print_main_params(params: dict) -> None: + keys = [ + "version", + "version_number", + "param_tag", + "training_label", + "gnn_type", + "depth", + "hidden_dimension", + "embedding_dimension", + "embedding_aggregation", + "layer_stacking", + "residual_connections", + "layer_norm", + "subgraph_features", + "pooling_func", + "dense_layers", + "dense_dim", + "input_dropout", + "latent_dropout", + "output_dimension", + "static_feature_dimension", + "static_rt_feature_dimension", + "prepare_additional_layers", + "rt_supported", + "ccs_supported", + ] + print("Model parameters:") + for key in keys: + print(f" {key}: {_get(params, key)}") + + # concise feature summary + atom_features = _get(params, "atom_features", []) + setup_features = _get(params, "setup_features", []) + rt_features = _get(params, "rt_features", []) + print(" atom_features:", atom_features if atom_features else "-") + print(" setup_features:", setup_features if setup_features else "-") + print(" rt_features:", rt_features if rt_features else "-") + + +def main() -> None: + args = parse_args() + model = _load_model(args.model) + params = model.model_params if hasattr(model, "model_params") else {} + print(f"Loaded model: {args.model}") + if args.as_json: + print(json.dumps(params, indent=2, sort_keys=True)) + return + _print_main_params(params) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index a939c87..e3dfbcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,10 +35,11 @@ dependencies = [ fiora-predict = "fiora.cli.predict:main" fiora-train = "fiora.cli.train:main" fiora-eval = "fiora.cli.eval:main" +fiora-model-info = "fiora.cli.model_info:main" [tool.setuptools] include-package-data = true -script-files = ["scripts/fiora-predict", "scripts/fiora-train", "scripts/fiora-eval"] # like setup.py `scripts=[...]` (not console entry points) 
+script-files = ["scripts/fiora-predict", "scripts/fiora-train", "scripts/fiora-eval", "scripts/fiora-model-info"] # like setup.py `scripts=[...]` (not console entry points) [tool.setuptools.packages.find] include = ["fiora", "fiora.*"] diff --git a/scripts/fiora-model-info b/scripts/fiora-model-info new file mode 100644 index 0000000..18f1019 --- /dev/null +++ b/scripts/fiora-model-info @@ -0,0 +1,6 @@ +#! /usr/bin/env python +from fiora.cli.model_info import main + + +if __name__ == "__main__": + main() diff --git a/tests/test_fiora_model_info.py b/tests/test_fiora_model_info.py new file mode 100644 index 0000000..2387758 --- /dev/null +++ b/tests/test_fiora_model_info.py @@ -0,0 +1,50 @@ +import io +import os +import sys + +import unittest +from unittest.mock import patch +import contextlib + +from importlib.util import spec_from_loader, module_from_spec +from importlib.machinery import SourceFileLoader + +spec = spec_from_loader( + "fiora-model-info", + SourceFileLoader("fiora-model-info", os.getcwd() + "/scripts/fiora-model-info"), +) +fiora_model_info = module_from_spec(spec) +spec.loader.exec_module(fiora_model_info) +sys.modules["fiora_model_info"] = fiora_model_info + + +class TestFioraModelInfo(unittest.TestCase): + def test_missing_args(self): + f = io.StringIO() + with patch("sys.argv", ["main"]): + with self.assertRaises(SystemExit) as cm, contextlib.redirect_stderr(f): + fiora_model_info.main() + self.assertEqual(cm.exception.code, 2) + self.assertTrue(f.getvalue().startswith("usage:")) + + def test_help(self): + f = io.StringIO() + with patch("sys.argv", ["main", "-h"]): + with self.assertRaises(SystemExit) as cm, contextlib.redirect_stdout(f): + fiora_model_info.main() + self.assertEqual(cm.exception.code, 0) + self.assertTrue(f.getvalue().startswith("usage:")) + self.assertTrue("--model MODEL" in f.getvalue()) + self.assertTrue("--as-json" in f.getvalue()) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + suite.addTests( + [ + TestFioraModelInfo("test_missing_args"), + TestFioraModelInfo("test_help"), + ] + ) + runner = unittest.TextTestRunner() + runner.run(suite) From f44b54524eee3e821fbf74b0b2b1210bb59c617f Mon Sep 17 00:00:00 2001 From: Philipp Benner Date: Fri, 6 Mar 2026 13:04:59 +0100 Subject: [PATCH 07/15] lightning fabric --- README.md | 6 + constraints.txt | 1 + fiora/GNN/SpectralTrainer.py | 409 ----------------------------- fiora/GNN/fabric_training.py | 487 +++++++++++++++++++++++++++++++++++ fiora/cli/train.py | 135 ++++++---- notebooks/train_model.ipynb | 126 +++++---- pyproject.toml | 1 + requirements.txt | 1 + 8 files changed, 658 insertions(+), 508 deletions(-) delete mode 100644 fiora/GNN/SpectralTrainer.py create mode 100644 fiora/GNN/fabric_training.py diff --git a/README.md b/README.md index f6039ed..a927797 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,12 @@ fiora-train \ --precursor-modes "[M+H]+,[M-H]-,[M]+,[M]-" ``` +To persist per-epoch training history, add `--history-out` (supports `.json` or `.csv`): + +```bash +fiora-train ... 
--history-out checkpoints/fiora_history.json +``` + ### Model Evaluation CLI You can evaluate a trained checkpoint on validation/test splits with: diff --git a/constraints.txt b/constraints.txt index 7332e87..583cf8c 100644 --- a/constraints.txt +++ b/constraints.txt @@ -142,6 +142,7 @@ threadpoolctl==3.6.0 tinycss2==1.4.0 tomli==2.2.1 torch==2.6.0 +lightning-fabric==2.6.1 torch-geometric==2.6.1 torchmetrics==1.8.1 tornado==6.4.2 diff --git a/fiora/GNN/SpectralTrainer.py b/fiora/GNN/SpectralTrainer.py deleted file mode 100644 index 08f2f29..0000000 --- a/fiora/GNN/SpectralTrainer.py +++ /dev/null @@ -1,409 +0,0 @@ -import numpy as np -import torch -from torch.utils.data import DataLoader -import torch_geometric.loader as geom_loader -from torchmetrics import ( - MetricTracker, - MetricCollection, -) -from typing import Literal, List, Any, Dict - -from fiora.GNN.Trainer import Trainer -from fiora.GNN.Losses import WeightedMSELoss, WeightedMAELoss - -""" -GNN Trainer -""" - -TQDM_DATA_THRESHOLD = 10000 - - -class SpectralTrainer(Trainer): - def __init__( - self, - data: Any, - train_val_split: float = 0.8, - split_by_group: bool = False, - only_training: bool = False, - train_keys: List[int] | None = None, - val_keys: List[int] | None = None, - y_tag: str = "y", - metric_dict: Dict = None, - problem_type: Literal[ - "classification", "regression", "softmax_regression" - ] = "classification", - library: Literal["standard", "geometric"] = "geometric", - num_workers: int = 0, - seed: int = 42, - device: str = "cpu", - ): - - super().__init__( - data, - train_val_split, - split_by_group, - only_training, - train_keys, - val_keys, - seed, - num_workers, - device, - ) - self.y_tag = y_tag - self.problem_type = problem_type - - # Initialize torch metrics based on dictionary - if metric_dict: - self.metrics = { - data_split: MetricTracker( - MetricCollection({t: M() for t, M in metric_dict.items()}), - maximize=False, - ).to(device) - for data_split in ["train", "val", "masked_val", "test"] - } - else: - self.metrics = self._get_default_metrics(problem_type) - self.loader_base = ( - geom_loader.DataLoader if library == "geometric" else DataLoader - ) - - @staticmethod - def _to_float(value): - if isinstance(value, torch.Tensor): - return float(value.detach().cpu().item()) - return float(value) - - @staticmethod - def _build_progress_iterator(dataloader, enabled=False, desc=""): - if not enabled: - return dataloader - try: - from tqdm.auto import tqdm - - return tqdm(dataloader, total=len(dataloader), desc=desc, leave=False) - except Exception: - return dataloader - - @staticmethod - def _format_metric(stats): - if "kl" in stats: - return "kl", float(stats["kl"].detach().cpu().item()) - if "mse" in stats: - rmse = torch.sqrt(stats["mse"]) - return "rmse", float(rmse.detach().cpu().item()) - if "mae" in stats: - return "mae", float(stats["mae"].detach().cpu().item()) - if "acc" in stats: - return "acc", float(stats["acc"].detach().cpu().item()) - key = next(iter(stats.keys())) - val = stats[key] - if isinstance(val, torch.Tensor): - val = float(val.detach().cpu().item()) - return key, float(val) - - def _training_loop( - self, - model, - dataloader, - optimizer, - loss_fn, - metrics, - with_weights=False, - with_RT=False, - with_CCS=False, - rt_metric=False, - show_progress=False, - progress_desc="Train", - ): - training_loss = 0 - metrics.increment() - num_batches = 0 - - iterator = self._build_progress_iterator( - dataloader, enabled=show_progress, desc=progress_desc - ) - for _, batch in 
enumerate(iterator): - # Feed forward - model.train() - - y_pred = model(batch, with_RT=with_RT, with_CCS=with_CCS) - kwargs = {} - if with_weights: - kwargs = {"weight": batch["weight_tensor"]} - if getattr(loss_fn, "requires_segment_ptr", False): - kwargs["segment_ptr"] = y_pred.get("segment_ptr") - - # Compute loss - loss = loss_fn( - y_pred["fragment_probs"], batch[self.y_tag], **kwargs - ) # with logits - if not rt_metric: - metrics( - y_pred["fragment_probs"], batch[self.y_tag], **kwargs - ) # call update - - # Add RT and CCS to loss - if with_RT: - if with_weights: - kwargs["weight"] = batch["weight"][batch["retention_mask"]] - loss_rt = loss_fn( - y_pred["rt"][batch["retention_mask"]], - batch["retention_time"][batch["retention_mask"]], - **kwargs, - ) - loss = loss + loss_rt - - if with_CCS: - if with_weights: - kwargs["weight"] = batch["weight"][batch["ccs_mask"]] - loss_ccs = loss_fn( - y_pred["ccs"][batch["ccs_mask"]], - batch["ccs"][batch["ccs_mask"]], - **kwargs, - ) - loss = loss + loss_ccs - - if rt_metric: - metrics( - y_pred["rt"][batch["retention_mask"]], - batch["retention_time"][batch["retention_mask"]], - **kwargs, - ) # call update - metrics( - y_pred["ccs"][batch["ccs_mask"]], - batch["ccs"][batch["ccs_mask"]], - **kwargs, - ) # call update - - # Backpropagate - optimizer.zero_grad() - loss.backward() - optimizer.step() - training_loss += self._to_float(loss) - num_batches += 1 - - # End of training cycle: Evaluation - stats = metrics.compute() - training_loss /= max(num_batches, 1) - return stats, training_loss - - def _validation_loop( - self, - model, - dataloader, - loss_fn, - metrics, - with_weights=False, - with_RT=False, - with_CCS=False, - rt_metric=False, - mask_name=None, - show_progress=False, - progress_desc="Validation", - ): - metrics.increment() - validation_loss = 0 - num_batches = 0 - with torch.no_grad(): - iterator = self._build_progress_iterator( - dataloader, enabled=show_progress, desc=progress_desc - ) - for _, batch in enumerate(iterator): - model.eval() - y_pred = model(batch, with_RT=with_RT, with_CCS=with_CCS) - if mask_name: - kwargs = {} - if with_weights: - kwargs = {"weight": batch["weight_tensor"][batch[mask_name]]} - metrics.update( - y_pred["fragment_probs"][batch[mask_name]], - batch[self.y_tag][batch[mask_name]], - **kwargs, - ) - if not rt_metric and torch.any(batch[mask_name]): - batch_loss = loss_fn( - y_pred["fragment_probs"][batch[mask_name]], - batch[self.y_tag][batch[mask_name]], - **kwargs, - ) - validation_loss += self._to_float(batch_loss) - num_batches += 1 - else: - kwargs = {} - if with_weights: - kwargs = {"weight": batch["weight_tensor"]} - if getattr(loss_fn, "requires_segment_ptr", False): - kwargs["segment_ptr"] = y_pred.get("segment_ptr") - if not rt_metric: - metrics.update( - y_pred["fragment_probs"], batch[self.y_tag], **kwargs - ) - batch_loss = loss_fn( - y_pred["fragment_probs"], batch[self.y_tag], **kwargs - ) - validation_loss += self._to_float(batch_loss) - num_batches += 1 - if rt_metric: - metrics( - y_pred["rt"][batch["retention_mask"]], - batch["retention_time"][batch["retention_mask"]], - **kwargs, - ) # call update - metrics( - y_pred["ccs"][batch["ccs_mask"]], - batch["ccs"][batch["ccs_mask"]], - **kwargs, - ) # call update - - # End of Validation cycle - stats = metrics.compute() - if num_batches > 0: - validation_loss /= num_batches - else: - validation_loss = float("nan") - return stats, validation_loss - - # Training function - def train( - self, - model, - optimizer, - loss_fn, - 
scheduler=None, - batch_size=16, - epochs=2, - val_every_n_epochs=1, - use_validation_mask=False, - with_RT=True, - with_CCS=True, - rt_metric=False, - mask_name="validation_mask", - save_path: str | None = None, - tag="", - ) -> Dict[str, Any]: - - # Set up checkpoint system and model info - if save_path is None: - save_path = f"../../checkpoint_{tag}.best.pt" - self._init_checkpoint_system(save_path=save_path) - self._init_history() - model.model_params["training_label"] = self.y_tag - - # Stage data into dataloader - training_loader = self.loader_base( - self.training_data, - batch_size=batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - if not self.only_training: - validation_loader = self.loader_base( - self.validation_data, - batch_size=batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - using_weighted_loss_func = isinstance( - loss_fn, (WeightedMSELoss, WeightedMAELoss) - ) - show_train_progress = len(self.training_data) > TQDM_DATA_THRESHOLD - show_val_progress = (not self.only_training) and ( - len(self.validation_data) > TQDM_DATA_THRESHOLD - ) - - # Main loop - for e in range(epochs): - # Training - train_stats, train_loss = self._training_loop( - model, - training_loader, - optimizer, - loss_fn, - self.metrics["train"], - with_weights=using_weighted_loss_func, - with_RT=with_RT, - with_CCS=with_CCS, - rt_metric=rt_metric, - show_progress=show_train_progress, - progress_desc=f"Train {e + 1}/{epochs}", - ) - - # Validation - is_val_cycle = not self.only_training and ( - (e + 1) % val_every_n_epochs == 0 - ) - if is_val_cycle: - val_stats, val_loss = self._validation_loop( - model, - validation_loader, - loss_fn, - self.metrics["masked_val"] - if use_validation_mask - else self.metrics["val"], - with_weights=using_weighted_loss_func, - with_RT=with_RT, - with_CCS=with_CCS, - rt_metric=rt_metric, - mask_name=mask_name if use_validation_mask else None, - show_progress=show_val_progress, - progress_desc=f"Val {e + 1}/{epochs}", - ) - val_metric_name, val_metric_value = self._format_metric(val_stats) - else: - val_stats, val_loss = None, float("nan") - val_metric_name, val_metric_value = None, None - - train_metric_name, train_metric_value = self._format_metric(train_stats) - if val_stats is not None: - val_metric_str = f"val_{val_metric_name}: {val_metric_value:.4f}" - else: - val_metric_str = "val_metric: n/a" - val_loss_str = f"{val_loss:.4f}" if not np.isnan(val_loss) else "n/a" - print( - f"Epoch {e + 1}/{epochs} - loss: {train_loss:.4f} - " - f"val_loss: {val_loss_str} - " - f"train_{train_metric_name}: {train_metric_value:.4f} - " - f"{val_metric_str}" - ) - - # End of epoch: Advance scheduler - if scheduler: - if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - last_lr = scheduler.get_last_lr()[0] - if is_val_cycle: - scheduler.step(val_metric_value) - if scheduler.get_last_lr()[0] < last_lr: - print( - f"\t >> Learning rate reduced from {last_lr:1.0e} to {scheduler.get_last_lr()[0]:1.0e}" - ) - else: - scheduler.step() - - # Save history - if is_val_cycle: - # Update checkpoint - if val_metric_value < self.checkpoint_stats["val_loss"]: - checkpoint_data = { - "epoch": e + 1, - "val_loss": val_metric_value, - "val_metric_name": val_metric_name, - "sqrt_val_loss": val_metric_value, - } - if "mse" in val_stats: - checkpoint_data["sqrt_val_loss"] = self._to_float( - torch.sqrt(val_stats["mse"]) - ) - self._update_checkpoint( - checkpoint_data, - model, - ) - print(f"\t >> Set new checkpoint to epoch {e + 1}") - current_lr = ( - 
scheduler.get_last_lr()[0] - if scheduler is not None - else optimizer.param_groups[0]["lr"] - ) - self._update_history(e + 1, train_stats, val_stats, lr=current_lr) - - print("Finished Training!") - return self.checkpoint_stats diff --git a/fiora/GNN/fabric_training.py b/fiora/GNN/fabric_training.py new file mode 100644 index 0000000..5234355 --- /dev/null +++ b/fiora/GNN/fabric_training.py @@ -0,0 +1,487 @@ +import random +from typing import Callable + +import numpy as np +import torch +import torch_geometric.loader as geom_loader +from lightning_fabric import Fabric +from torchmetrics import MeanSquaredError + +from fiora.GNN.Losses import WeightedMAELoss, WeightedMSELoss + +TQDM_DATA_THRESHOLD = 10000 + + +def is_weighted_loss(loss_fn) -> bool: + return isinstance(loss_fn, (WeightedMSELoss, WeightedMAELoss)) + + +def build_loss_kwargs( + batch, + y_pred, + loss_fn, + with_weights: bool, + mask: torch.Tensor | None = None, + include_segment_ptr: bool = True, +): + kwargs = {} + if with_weights: + kwargs["weight"] = ( + batch["weight_tensor"] if mask is None else batch["weight_tensor"][mask] + ) + if include_segment_ptr and getattr(loss_fn, "requires_segment_ptr", False): + kwargs["segment_ptr"] = y_pred.get("segment_ptr") + return kwargs + + +def add_rt_ccs_loss( + loss, + y_pred, + batch, + loss_fn, + with_weights: bool, + with_rt: bool, + with_ccs: bool, +): + if with_rt: + kwargs_rt = {} + if with_weights: + kwargs_rt["weight"] = batch["weight"][batch["retention_mask"]] + loss = loss + loss_fn( + y_pred["rt"][batch["retention_mask"]], + batch["retention_time"][batch["retention_mask"]], + **kwargs_rt, + ) + if with_ccs: + kwargs_ccs = {} + if with_weights: + kwargs_ccs["weight"] = batch["weight"][batch["ccs_mask"]] + loss = loss + loss_fn( + y_pred["ccs"][batch["ccs_mask"]], + batch["ccs"][batch["ccs_mask"]], + **kwargs_ccs, + ) + return loss + + +def safe_metric_update(metric, preds, target, kwargs: dict | None = None): + kwargs = kwargs or {} + update = getattr(metric, "update", None) + if callable(update): + try: + update(preds, target, **kwargs) + return + except TypeError: + update(preds, target) + return + try: + metric(preds, target, **kwargs) + except TypeError: + metric(preds, target) + + +def metric_label_and_value(metric_or_stats, preferred_key: str | None = None): + stats = ( + metric_or_stats.compute() + if hasattr(metric_or_stats, "compute") + else metric_or_stats + ) + + if isinstance(stats, dict): + if preferred_key is not None and preferred_key in stats: + key = preferred_key + else: + for candidate in ("kl", "mse", "mae", "acc"): + if candidate in stats: + key = candidate + break + else: + key = next(iter(stats.keys())) + value = stats[key] + else: + key = preferred_key or "metric" + value = stats + + label = "rmse" if key == "mse" else key + if key == "mse": + value = torch.sqrt(value) + if isinstance(value, torch.Tensor): + value = float(value.detach().cpu().item()) + else: + value = float(value) + return label, value + + +def resolve_fabric_runtime(device: str): + if device.startswith("cuda"): + if ":" in device: + return "cuda", [int(device.split(":")[-1])] + return "cuda", 1 + if device.startswith("mps"): + return "mps", 1 + return "cpu", 1 + + +def seed_everything(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def build_progress_iterator(dataloader, enabled=False, desc=""): + if not enabled: + return dataloader + try: + from tqdm.auto import tqdm + 
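+        # tqdm is treated as optional: any failure here falls back to the
+        # undecorated dataloader below.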
+ return tqdm(dataloader, total=len(dataloader), desc=desc, leave=False) + except Exception: + return dataloader + + +def unwrap_model(model): + return model.module if hasattr(model, "module") else model + + +def run_epoch( + fabric: Fabric, + model: torch.nn.Module, + dataloader, + loss_fn, + metric, + metric_name: str, + y_tag: str, + with_weights: bool, + with_rt: bool, + with_ccs: bool, + rt_metric: bool, + optimizer=None, + use_validation_mask: bool = False, + mask_name: str = "validation_mask", + show_progress: bool = False, + progress_desc: str = "", +): + is_training = optimizer is not None + if is_training: + model.train() + else: + model.eval() + metric.reset() + + loss_total = 0.0 + loss_batches = 0 + iterator = build_progress_iterator( + dataloader, enabled=show_progress, desc=progress_desc + ) + + for batch in iterator: + batch = batch.to(fabric.device) + with torch.set_grad_enabled(is_training): + y_pred = model(batch, with_RT=with_rt, with_CCS=with_ccs) + + if use_validation_mask: + mask = batch[mask_name] + if torch.any(mask): + kwargs = build_loss_kwargs( + batch=batch, + y_pred=y_pred, + loss_fn=loss_fn, + with_weights=with_weights, + mask=mask, + include_segment_ptr=False, + ) + loss = loss_fn( + y_pred["fragment_probs"][mask], + batch[y_tag][mask], + **kwargs, + ) + if not rt_metric: + safe_metric_update( + metric, + y_pred["fragment_probs"][mask], + batch[y_tag][mask], + kwargs, + ) + else: + if with_rt: + safe_metric_update( + metric, + y_pred["rt"][batch["retention_mask"]], + batch["retention_time"][batch["retention_mask"]], + {}, + ) + if with_ccs: + safe_metric_update( + metric, + y_pred["ccs"][batch["ccs_mask"]], + batch["ccs"][batch["ccs_mask"]], + {}, + ) + loss = add_rt_ccs_loss( + loss=loss, + y_pred=y_pred, + batch=batch, + loss_fn=loss_fn, + with_weights=with_weights, + with_rt=with_rt, + with_ccs=with_ccs, + ) + loss_total += float(loss.detach().cpu().item()) + loss_batches += 1 + continue + + kwargs = build_loss_kwargs( + batch=batch, + y_pred=y_pred, + loss_fn=loss_fn, + with_weights=with_weights, + include_segment_ptr=True, + ) + loss = loss_fn(y_pred["fragment_probs"], batch[y_tag], **kwargs) + if not rt_metric: + safe_metric_update( + metric, y_pred["fragment_probs"], batch[y_tag], kwargs + ) + else: + if with_rt: + safe_metric_update( + metric, + y_pred["rt"][batch["retention_mask"]], + batch["retention_time"][batch["retention_mask"]], + {}, + ) + if with_ccs: + safe_metric_update( + metric, + y_pred["ccs"][batch["ccs_mask"]], + batch["ccs"][batch["ccs_mask"]], + {}, + ) + + loss = add_rt_ccs_loss( + loss=loss, + y_pred=y_pred, + batch=batch, + loss_fn=loss_fn, + with_weights=with_weights, + with_rt=with_rt, + with_ccs=with_ccs, + ) + loss_total += float(loss.detach().cpu().item()) + loss_batches += 1 + + if is_training: + optimizer.zero_grad(set_to_none=True) + fabric.backward(loss) + optimizer.step() + + avg_loss = loss_total / max(loss_batches, 1) if loss_batches > 0 else float("nan") + metric_label, metric_value = metric_label_and_value( + metric, preferred_key=metric_name + ) + return avg_loss, metric_label, metric_value + + +def train_fabric_loop( + *, + model, + train_data, + val_data, + loss_fn, + metric_dict, + y_label: str, + device: str, + batch_size: int, + num_workers: int, + epochs: int, + val_every: int, + learning_rate: float, + weight_decay: float, + scheduler_name: str, + scheduler_patience: int, + scheduler_factor: float, + with_rt: bool, + with_ccs: bool, + rt_metric: bool, + use_validation_mask: bool, + validation_mask_name: str, 
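+    # The arguments below are optional wiring: when optimizer/scheduler are None,
+    # the loop builds them from learning_rate, weight_decay and scheduler_name.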
+ output_path: str | None = None, + optimizer=None, + scheduler=None, + progress_threshold: int = TQDM_DATA_THRESHOLD, + launch_fabric: bool = True, + logger: Callable[[str], None] | None = print, +): + has_validation = len(val_data) > 0 + accelerator, devices = resolve_fabric_runtime(device) + fabric = Fabric(accelerator=accelerator, devices=devices) + if launch_fabric: + fabric.launch() + + with_weights = is_weighted_loss(loss_fn) + if metric_dict: + metric_name, metric_cls = next(iter(metric_dict.items())) + train_metric = metric_cls().to(fabric.device) + val_metric = metric_cls().to(fabric.device) + else: + metric_name = "mse" + train_metric = MeanSquaredError().to(fabric.device) + val_metric = MeanSquaredError().to(fabric.device) + + if optimizer is None: + optimizer = torch.optim.Adam( + model.parameters(), lr=learning_rate, weight_decay=weight_decay + ) + if scheduler is None and scheduler_name == "plateau": + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + patience=scheduler_patience, + factor=scheduler_factor, + mode="min", + ) + + train_loader = geom_loader.DataLoader( + train_data, + batch_size=batch_size, + num_workers=num_workers, + shuffle=True, + ) + val_loader = None + if has_validation: + val_loader = geom_loader.DataLoader( + val_data, + batch_size=batch_size, + num_workers=num_workers, + shuffle=False, + ) + + model, optimizer = fabric.setup(model, optimizer) + if val_loader is not None: + train_loader, val_loader = fabric.setup_dataloaders(train_loader, val_loader) + else: + train_loader = fabric.setup_dataloaders(train_loader) + + show_train_progress = len(train_data) > progress_threshold + show_val_progress = has_validation and (len(val_data) > progress_threshold) + + best_metric = float("inf") + best_epoch = -1 + history = { + "epoch": [], + "train_error": [], + "sqrt_train_error": [], + "val_error": [], + "sqrt_val_error": [], + "lr": [], + } + + for epoch in range(1, epochs + 1): + train_loss, train_metric_label, train_metric_value = run_epoch( + fabric=fabric, + model=model, + dataloader=train_loader, + loss_fn=loss_fn, + metric=train_metric, + metric_name=metric_name, + y_tag=y_label, + with_weights=with_weights, + with_rt=with_rt, + with_ccs=with_ccs, + rt_metric=rt_metric, + optimizer=optimizer, + show_progress=show_train_progress, + progress_desc=f"Train {epoch}/{epochs}", + ) + + is_val_cycle = has_validation and (epoch % val_every == 0) + if is_val_cycle: + val_loss, val_metric_label, val_metric_value = run_epoch( + fabric=fabric, + model=model, + dataloader=val_loader, + loss_fn=loss_fn, + metric=val_metric, + metric_name=metric_name, + y_tag=y_label, + with_weights=with_weights, + with_rt=with_rt, + with_ccs=with_ccs, + rt_metric=rt_metric, + use_validation_mask=use_validation_mask, + mask_name=validation_mask_name, + show_progress=show_val_progress, + progress_desc=f"Val {epoch}/{epochs}", + ) + else: + val_loss = float("nan") + val_metric_label = train_metric_label + val_metric_value = float("nan") + + monitor_metric = None + if is_val_cycle: + monitor_metric = val_metric_value + elif not has_validation: + monitor_metric = train_metric_value + + if scheduler is not None: + prev_lr = optimizer.param_groups[0]["lr"] + if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + if monitor_metric is not None and not np.isnan(monitor_metric): + scheduler.step(monitor_metric) + else: + scheduler.step() + curr_lr = optimizer.param_groups[0]["lr"] + if logger is not None and fabric.is_global_zero and curr_lr < prev_lr: + 
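+                # Announce LR reductions once, from the global-zero rank only.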
logger( + f"\t >> Learning rate reduced from {prev_lr:1.0e} to {curr_lr:1.0e}" + ) + + if monitor_metric is not None and not np.isnan(monitor_metric): + if monitor_metric < best_metric: + best_metric = monitor_metric + best_epoch = epoch + if output_path is not None and fabric.is_global_zero: + unwrap_model(model).save(output_path) + if logger is not None: + logger(f"\t >> Set new checkpoint to epoch {epoch}") + + if (is_val_cycle or not has_validation) and fabric.is_global_zero: + history["epoch"].append(epoch) + history["train_error"].append(train_metric_value) + history["sqrt_train_error"].append(train_metric_value) + history["val_error"].append( + val_metric_value if is_val_cycle else float("nan") + ) + history["sqrt_val_error"].append( + val_metric_value if is_val_cycle else float("nan") + ) + history["lr"].append(optimizer.param_groups[0]["lr"]) + + if logger is not None and fabric.is_global_zero: + val_loss_str = f"{val_loss:.4f}" if not np.isnan(val_loss) else "n/a" + val_metric_str = ( + f"{val_metric_value:.4f}" if not np.isnan(val_metric_value) else "n/a" + ) + logger( + f"Epoch {epoch}/{epochs} - " + f"loss: {train_loss:.4f} - " + f"val_loss: {val_loss_str} - " + f"train_{train_metric_label}: {train_metric_value:.4f} - " + f"val_{val_metric_label}: {val_metric_str}" + ) + + if best_epoch < 0: + best_epoch = epochs + best_metric = float("nan") + if output_path is not None and fabric.is_global_zero: + unwrap_model(model).save(output_path) + + checkpoints = { + "epoch": best_epoch, + "val_loss": best_metric, + "sqrt_val_loss": best_metric, + "file": output_path, + } + return checkpoints, history diff --git a/fiora/cli/train.py b/fiora/cli/train.py index ff906fc..ebc6913 100644 --- a/fiora/cli/train.py +++ b/fiora/cli/train.py @@ -10,11 +10,13 @@ import pandas as pd import torch from rdkit import RDLogger +from sklearn.model_selection import train_test_split from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder from fiora.GNN.FioraModel import FioraModel +from fiora.GNN.fabric_training import seed_everything, train_fabric_loop from fiora.GNN.Losses import ( GraphwiseKLLoss, GraphwiseKLLossMetric, @@ -23,7 +25,6 @@ WeightedMSELoss, WeightedMSEMetric, ) -from fiora.GNN.SpectralTrainer import SpectralTrainer from fiora.IO.LibraryLoader import LibraryLoader from fiora.MOL.Metabolite import Metabolite from fiora.MOL.MetaboliteIndex import MetaboliteIndex @@ -173,6 +174,11 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Disable index_col when reading CSV.", ) + parser.add_argument( + "--history-out", + default=None, + help="Optional path to save training history (.json or .csv).", + ) return parser.parse_args() @@ -255,10 +261,60 @@ def _choose_loss(loss_name: str): raise ValueError(f"Unknown loss: {loss_name}") +def _save_history(history: dict, output_path: str) -> None: + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + if output_path.lower().endswith(".csv"): + pd.DataFrame(history).to_csv(output_path, index=False) + else: + with open(output_path, "w") as fp: + json.dump(history, fp, indent=2) + + +def _split_geo_data( + geo_data, + split_by_group: bool, + train_val_split: float, + seed: int, + train_keys: list[int] | None = None, + val_keys: list[int] | None = None, +): + train_keys = train_keys or [] + val_keys = val_keys or [] + if len(geo_data) == 0: + return [], [] + + if split_by_group and 
hasattr(geo_data[0], "group_id"): + group_ids = np.array([int(getattr(x, "group_id")) for x in geo_data]) + keys = np.unique(group_ids) + if len(train_keys) > 0 and len(val_keys) > 0: + train_set = set(int(x) for x in train_keys) + val_set = set(int(x) for x in val_keys) + print("Using pre-set train/validation keys") + else: + tr, va = train_test_split( + keys, test_size=1 - train_val_split, random_state=seed + ) + train_set = set(int(x) for x in tr) + val_set = set(int(x) for x in va) + train_data = [x for x in geo_data if int(getattr(x, "group_id")) in train_set] + val_data = [x for x in geo_data if int(getattr(x, "group_id")) in val_set] + return train_data, val_data + + train_size = int(len(geo_data) * train_val_split) + rng = np.random.default_rng(seed) + indices = np.arange(len(geo_data)) + rng.shuffle(indices) + train_idx = set(indices[:train_size].tolist()) + train_data = [geo_data[i] for i in range(len(geo_data)) if i in train_idx] + val_data = [geo_data[i] for i in range(len(geo_data)) if i not in train_idx] + return train_data, val_data + + def main() -> None: args = parse_args() dev = _resolve_device(args.device) np.seterr(invalid="ignore") + seed_everything(args.seed) index_col = None if args.no_index_col else args.index_col loader = LibraryLoader() @@ -328,7 +384,6 @@ def main() -> None: } # Build metabolites - metabolites = [] invalid_rows = [] for idx, row in df.iterrows(): smiles = row.get(args.smiles_col) @@ -368,7 +423,6 @@ def main() -> None: else: mol.set_loss_weight(1.0) - metabolites.append(mol) df.at[idx, "Metabolite"] = mol if invalid_rows: @@ -450,7 +504,7 @@ def main() -> None: # Geometric data geo_data = [] for _, row in df_train.iterrows(): - data = row["Metabolite"].as_geometric_data().to(dev) + data = row["Metabolite"].as_geometric_data() if args.group_id_col in df_train.columns: try: data.group_id = int(row[args.group_id_col]) @@ -506,11 +560,11 @@ def main() -> None: state_path = args.resume.replace(".pt", "_state.pt") params_path = args.resume.replace(".pt", "_params.json") if os.path.exists(state_path) and os.path.exists(params_path): - model = FioraModel.load_from_state_dict(args.resume).to(dev) + model = FioraModel.load_from_state_dict(args.resume) else: - model = FioraModel.load(args.resume).to(dev) + model = FioraModel.load(args.resume) else: - model = FioraModel(model_params).to(dev) + model = FioraModel(model_params) if (args.with_rt or args.with_ccs) and not model.model_params.get( "prepare_additional_layers", False @@ -522,56 +576,47 @@ def main() -> None: loss_fn, metric_dict = _choose_loss(args.loss) split_by_group = args.split_by_group and args.group_id_col in df_train.columns - only_training = len(val_keys) == 0 and not args.use_validation_mask - - trainer = SpectralTrainer( + train_data, val_data = _split_geo_data( geo_data, - y_tag=args.y_label, - problem_type="regression", - train_val_split=args.train_val_split, split_by_group=split_by_group, - only_training=only_training, + train_val_split=args.train_val_split, + seed=args.seed, train_keys=train_keys, val_keys=val_keys, - metric_dict=metric_dict, - seed=args.seed, - device=dev, - num_workers=args.num_workers, ) - - optimizer = torch.optim.Adam( - model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay - ) - - scheduler = None - if args.scheduler == "plateau": - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, - patience=args.scheduler_patience, - factor=args.scheduler_factor, - mode="min", - ) + has_validation = len(val_data) > 0 + 
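+    # An empty validation split is allowed; train_fabric_loop then monitors
+    # the training metric for checkpointing instead.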
print(f"Train/validation split: {len(train_data)} / {len(val_data)}") output_path = args.output os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) - - checkpoints = trainer.train( - model, - optimizer, - loss_fn, - scheduler=scheduler, + checkpoints, history = train_fabric_loop( + model=model, + train_data=train_data, + val_data=val_data, + loss_fn=loss_fn, + metric_dict=metric_dict, + y_label=args.y_label, + device=dev, batch_size=args.batch_size, + num_workers=args.num_workers, epochs=args.epochs, - val_every_n_epochs=args.val_every, - use_validation_mask=args.use_validation_mask, - with_RT=args.with_rt, - with_CCS=args.with_ccs, + val_every=args.val_every, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + scheduler_name=args.scheduler, + scheduler_patience=args.scheduler_patience, + scheduler_factor=args.scheduler_factor, + with_rt=args.with_rt, + with_ccs=args.with_ccs, rt_metric=args.rt_metric, - mask_name=args.validation_mask_name, - save_path=output_path, - tag="train", + use_validation_mask=args.use_validation_mask, + validation_mask_name=args.validation_mask_name, + output_path=output_path, + logger=print, ) - + if args.history_out: + _save_history(history, args.history_out) + print(f"Saved training history to {args.history_out}") print(f"Finished training. Best checkpoint: {checkpoints['file']}") diff --git a/notebooks/train_model.ipynb b/notebooks/train_model.ipynb index ee1f97b..90ce3b6 100644 --- a/notebooks/train_model.ipynb +++ b/notebooks/train_model.ipynb @@ -532,7 +532,7 @@ "metadata": {}, "outputs": [], "source": [ - "from fiora.GNN.SpectralTrainer import SpectralTrainer\n", + "from fiora.GNN.fabric_training import train_fabric_loop\n", "from fiora.GNN.FioraModel import FioraModel\n", "from fiora.GNN.Losses import (\n", " WeightedMSELoss,\n", @@ -543,6 +543,7 @@ " GraphwiseKLLossMetric,\n", ")\n", "from fiora.MS.SimulationFramework import SimulationFramework\n", + "from sklearn.model_selection import train_test_split\n", "\n", "fiora = SimulationFramework(None, dev=dev)\n", "# fiora = SimulationFramework(None, dev=dev, with_RT=training_params[\"with_RT\"], with_CCS=training_params[\"with_CCS\"])\n", @@ -563,6 +564,28 @@ " loss_fn = torch.nn.MSELoss()\n", "\n", "\n", + "def split_geo_by_group(\n", + " geo_data, train_keys=None, val_keys=None, train_val_split=0.8, seed=42\n", + "):\n", + " train_keys = train_keys or []\n", + " val_keys = val_keys or []\n", + " if len(train_keys) > 0 and len(val_keys) > 0:\n", + " train_set = set(int(x) for x in train_keys)\n", + " val_set = set(int(x) for x in val_keys)\n", + " else:\n", + " group_ids = np.array([int(getattr(x, \"group_id\")) for x in geo_data])\n", + " keys = np.unique(group_ids)\n", + " tr, va = train_test_split(\n", + " keys, test_size=1 - train_val_split, random_state=seed\n", + " )\n", + " train_set = set(int(x) for x in tr)\n", + " val_set = set(int(x) for x in va)\n", + "\n", + " train_data = [x for x in geo_data if int(getattr(x, \"group_id\")) in train_set]\n", + " val_data = [x for x in geo_data if int(getattr(x, \"group_id\")) in val_set]\n", + " return train_data, val_data\n", + "\n", + "\n", "def train_new_model(\n", " continue_with_model=None,\n", " model_params=model_params,\n", @@ -581,17 +604,10 @@ " lr=training_params[\"learning_rate\"],\n", " weight_decay=training_params[\"weight_decay\"],\n", " )\n", + " save_path = f\"../../checkpoint_{tag}.best.pt\"\n", + "\n", " if all_together:\n", - " trainer = SpectralTrainer(\n", - " geo_data,\n", - " y_tag=y_label,\n", - " 
problem_type=\"regression\",\n", - " only_training=True,\n", - " metric_dict=metric_dict,\n", - " split_by_group=True,\n", - " seed=seed,\n", - " device=dev,\n", - " )\n", + " train_data, val_data = geo_data, []\n", " scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)\n", " else:\n", " train_keys, val_keys = (\n", @@ -606,36 +622,42 @@ " print(\n", " f\"Sample down to {train_fraction * 100}% with {len(train_keys)} training and {len(val_keys)} validation compounds \"\n", " )\n", - " trainer = SpectralTrainer(\n", - " geo_data,\n", - " y_tag=y_label,\n", - " problem_type=\"regression\",\n", - " train_keys=train_keys,\n", - " val_keys=val_keys,\n", - " metric_dict=metric_dict,\n", - " split_by_group=True,\n", - " seed=seed,\n", - " device=dev,\n", + " train_data, val_data = split_geo_by_group(\n", + " geo_data, train_keys=train_keys, val_keys=val_keys, seed=seed\n", " )\n", " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", " optimizer, patience=8, factor=0.5, mode=\"min\"\n", " )\n", "\n", - " checkpoints = trainer.train(\n", - " model,\n", - " optimizer,\n", - " loss_fn,\n", - " scheduler=scheduler,\n", + " checkpoints, history = train_fabric_loop(\n", + " model=model,\n", + " train_data=train_data,\n", + " val_data=val_data,\n", + " loss_fn=loss_fn,\n", + " metric_dict=metric_dict,\n", + " y_label=y_label,\n", + " device=dev,\n", " batch_size=training_params[\"batch_size\"],\n", + " num_workers=0,\n", " epochs=training_params[\"epochs\"],\n", - " val_every_n_epochs=1,\n", - " with_CCS=training_params[\"with_CCS\"],\n", - " with_RT=training_params[\"with_RT\"],\n", + " val_every=val_interval,\n", + " learning_rate=training_params[\"learning_rate\"],\n", + " weight_decay=training_params[\"weight_decay\"],\n", + " scheduler_name=\"none\",\n", + " scheduler_patience=8,\n", + " scheduler_factor=0.5,\n", + " with_rt=training_params[\"with_RT\"],\n", + " with_ccs=training_params[\"with_CCS\"],\n", + " rt_metric=False,\n", " use_validation_mask=False,\n", - " tag=tag,\n", - " ) # , mask_name=\"compiled_validation_maskALL\")\n", + " validation_mask_name=\"validation_mask\",\n", + " output_path=save_path,\n", + " optimizer=optimizer,\n", + " scheduler=scheduler,\n", + " logger=print,\n", + " )\n", " print(checkpoints)\n", - " return model, checkpoints, trainer\n", + " return model, checkpoints, history\n", "\n", "\n", "def simulate_all(model, DF):\n", @@ -677,7 +699,7 @@ "\n", " # Train the model with the updated parameters\n", " try:\n", - " model, checkpoints, trainer = train_new_model(\n", + " model, checkpoints, history = train_new_model(\n", " model_params=model_params,\n", " training_params=training_params,\n", " tag=f\"{prefix}_{i + 1}\",\n", @@ -687,7 +709,7 @@ " \"config\": param_override,\n", " \"model\": model if store_models else None,\n", " \"checkpoints\": checkpoints,\n", - " \"trainer\": trainer,\n", + " \"history\": history,\n", " }\n", " )\n", "\n", @@ -1532,7 +1554,7 @@ "\n", "if not GRID_SEARCH:\n", " print(f\"Training model\")\n", - " model, checkpoints, trainer = train_new_model() # continue_with_model=model)" + " model, checkpoints, history = train_new_model() # continue_with_model=model)" ] }, { @@ -1562,7 +1584,7 @@ " best_result = min(grid_results, key=lambda x: x[\"checkpoints\"][\"sqrt_val_loss\"])\n", " model = best_result[\"model\"]\n", " checkpoints = best_result[\"checkpoints\"]\n", - " trainer = best_result[\"trainer\"]\n", + " history = best_result[\"history\"]\n", " print(\n", " f\"Best model found with val_sqrt_error: 
{best_result['checkpoints']['sqrt_val_loss']}\"\n", " )\n", @@ -1599,26 +1621,26 @@ "import matplotlib.pyplot as plt\n", "\n", "# Convert numpy arrays to scalars if they are single-element arrays\n", - "trainer.history[\"train_error\"] = [\n", + "history[\"train_error\"] = [\n", " error.item() if isinstance(error, np.ndarray) and error.size == 1 else error\n", - " for error in trainer.history[\"train_error\"]\n", + " for error in history[\"train_error\"]\n", "]\n", - "trainer.history[\"val_error\"] = [\n", + "history[\"val_error\"] = [\n", " error.item() if isinstance(error, np.ndarray) and error.size == 1 else error\n", - " for error in trainer.history[\"val_error\"]\n", + " for error in history[\"val_error\"]\n", "]\n", - "trainer.history[\"lr\"] = [\n", + "history[\"lr\"] = [\n", " lr.item() if isinstance(lr, np.ndarray) and lr.size == 1 else lr\n", - " for lr in trainer.history[\"lr\"]\n", + " for lr in history[\"lr\"]\n", "]\n", "\n", "# Create a DataFrame from the tracker dictionary\n", "tracker_df = pd.DataFrame(\n", " {\n", - " \"epoch\": trainer.history[\"epoch\"],\n", - " \"train_rmse\": trainer.history[\"sqrt_train_error\"],\n", - " \"val_rmse\": trainer.history[\"sqrt_val_error\"],\n", - " \"lr\": trainer.history[\"lr\"],\n", + " \"epoch\": history[\"epoch\"],\n", + " \"train_rmse\": history[\"sqrt_train_error\"],\n", + " \"val_rmse\": history[\"sqrt_val_error\"],\n", + " \"lr\": history[\"lr\"],\n", " }\n", ")\n", "\n", @@ -1661,14 +1683,10 @@ "plt.title(\"Training and Validation Loss Over Epochs\")\n", "plt.legend()\n", "plt.show()\n", - "min_train_error = min(trainer.history[\"sqrt_train_error\"])\n", - "min_val_error = min(trainer.history[\"sqrt_val_error\"])\n", - "epoch_min_train_error = trainer.history[\"epoch\"][\n", - " np.argmin(trainer.history[\"sqrt_train_error\"])\n", - "]\n", - "epoch_min_val_error = trainer.history[\"epoch\"][\n", - " np.argmin(trainer.history[\"sqrt_val_error\"])\n", - "]\n", + "min_train_error = min(history[\"sqrt_train_error\"])\n", + "min_val_error = min(history[\"sqrt_val_error\"])\n", + "epoch_min_train_error = history[\"epoch\"][np.argmin(history[\"sqrt_train_error\"])]\n", + "epoch_min_val_error = history[\"epoch\"][np.argmin(history[\"sqrt_val_error\"])]\n", "print(f\"Minimum Training RMSE: {min_train_error:.5f} (Epoch {epoch_min_train_error})\")\n", "print(f\"Minimum Validation RMSE: {min_val_error:.5f} (Epoch {epoch_min_val_error})\")" ] diff --git a/pyproject.toml b/pyproject.toml index e3dfbcf..6e88b0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "pandas", "seaborn", "torch", + "lightning-fabric>=2.6,<2.7", "torch_geometric>=2.6,<2.7", "torchmetrics", "dill", diff --git a/requirements.txt b/requirements.txt index 7332e87..583cf8c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -142,6 +142,7 @@ threadpoolctl==3.6.0 tinycss2==1.4.0 tomli==2.2.1 torch==2.6.0 +lightning-fabric==2.6.1 torch-geometric==2.6.1 torchmetrics==1.8.1 tornado==6.4.2 From 6d37599fe22797a933dba7e716ddb8e19e75a1d9 Mon Sep 17 00:00:00 2001 From: Philipp Benner Date: Fri, 6 Mar 2026 14:48:21 +0100 Subject: [PATCH 08/15] performance tweaks --- README.md | 3 + fiora/GNN/fabric_training.py | 21 +++- fiora/cli/train.py | 214 +++++++++++++++++++++++++---------- 3 files changed, 175 insertions(+), 63 deletions(-) diff --git a/README.md b/README.md index a927797..4604481 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,9 @@ To persist per-epoch training history, add `--history-out` (supports `.json` or fiora-train ... 
--history-out checkpoints/fiora_history.json ``` +`pin_memory` is enabled automatically on CUDA; you can override with `--pin-memory` or `--no-pin-memory`. +`--num-workers` is used for both DataLoader workers and parallel preprocessing (thread-based metabolite graph/peak matching setup) in the training CLI. + ### Model Evaluation CLI You can evaluate a trained checkpoint on validation/test splits with: diff --git a/fiora/GNN/fabric_training.py b/fiora/GNN/fabric_training.py index 5234355..dc88b24 100644 --- a/fiora/GNN/fabric_training.py +++ b/fiora/GNN/fabric_training.py @@ -145,6 +145,13 @@ def unwrap_model(model): return model.module if hasattr(model, "module") else model +def move_batch_to_device(batch, device, non_blocking: bool): + try: + return batch.to(device, non_blocking=non_blocking) + except TypeError: + return batch.to(device) + + def run_epoch( fabric: Fabric, model: torch.nn.Module, @@ -162,6 +169,7 @@ def run_epoch( mask_name: str = "validation_mask", show_progress: bool = False, progress_desc: str = "", + non_blocking_transfer: bool = False, ): is_training = optimizer is not None if is_training: @@ -177,7 +185,9 @@ def run_epoch( ) for batch in iterator: - batch = batch.to(fabric.device) + batch = move_batch_to_device( + batch, fabric.device, non_blocking=non_blocking_transfer + ) with torch.set_grad_enabled(is_training): y_pred = model(batch, with_RT=with_rt, with_CCS=with_ccs) @@ -313,9 +323,14 @@ def train_fabric_loop( progress_threshold: int = TQDM_DATA_THRESHOLD, launch_fabric: bool = True, logger: Callable[[str], None] | None = print, + pin_memory: bool | None = None, ): has_validation = len(val_data) > 0 accelerator, devices = resolve_fabric_runtime(device) + if pin_memory is None: + pin_memory = accelerator == "cuda" + use_non_blocking_transfer = bool(pin_memory and accelerator == "cuda") + fabric = Fabric(accelerator=accelerator, devices=devices) if launch_fabric: fabric.launch() @@ -347,6 +362,7 @@ def train_fabric_loop( batch_size=batch_size, num_workers=num_workers, shuffle=True, + pin_memory=pin_memory, ) val_loader = None if has_validation: @@ -355,6 +371,7 @@ def train_fabric_loop( batch_size=batch_size, num_workers=num_workers, shuffle=False, + pin_memory=pin_memory, ) model, optimizer = fabric.setup(model, optimizer) @@ -393,6 +410,7 @@ def train_fabric_loop( optimizer=optimizer, show_progress=show_train_progress, progress_desc=f"Train {epoch}/{epochs}", + non_blocking_transfer=use_non_blocking_transfer, ) is_val_cycle = has_validation and (epoch % val_every == 0) @@ -413,6 +431,7 @@ def train_fabric_loop( mask_name=validation_mask_name, show_progress=show_val_progress, progress_desc=f"Val {epoch}/{epochs}", + non_blocking_transfer=use_non_blocking_transfer, ) else: val_loss = float("nan") diff --git a/fiora/cli/train.py b/fiora/cli/train.py index ebc6913..e67941c 100644 --- a/fiora/cli/train.py +++ b/fiora/cli/train.py @@ -5,6 +5,7 @@ import os import re import warnings +from concurrent.futures import ThreadPoolExecutor import numpy as np import pandas as pd @@ -142,6 +143,12 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--weight-upper-limit", type=float, default=1000.0) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--num-workers", type=int, default=0) + parser.add_argument( + "--pin-memory", + action=argparse.BooleanOptionalAction, + default=None, + help="Pin host memory for DataLoader (auto: enabled for CUDA).", + ) parser.add_argument("--val-every", type=int, default=1) parser.add_argument( 
"--use-validation-mask", @@ -220,22 +227,127 @@ def _safe_metabolite(smiles: str): return None -def _build_summary_from_columns(row, metadata_key_map): +def _is_missing_value(val) -> bool: + return val is None or (isinstance(val, float) and np.isnan(val)) + + +def _build_summary_from_record(record: dict, metadata_key_map) -> dict: summary = {} for key, cols in metadata_key_map.items(): if not isinstance(cols, (list, tuple)): cols = [cols] for col in cols: - if col in row.index: - value = row[col] - if value is not None and not ( - isinstance(value, float) and np.isnan(value) - ): + if col in record: + value = record[col] + if not _is_missing_value(value): summary[key] = value break return summary +def _parallel_map(func, tasks, num_workers: int): + if num_workers > 1: + with ThreadPoolExecutor(max_workers=num_workers) as executor: + yield from executor.map(func, tasks) + else: + for task in tasks: + yield func(task) + + +def _progress_iterator(iterable, total: int, desc: str): + try: + from tqdm.auto import tqdm + + return tqdm(iterable, total=total, desc=desc) + except Exception: + return iterable + + +def _prepare_metabolite_task(task): + ( + idx, + record, + smiles_col, + group_id_col, + summary_col, + loss_weight_col, + metadata_key_map, + node_encoder, + bond_encoder, + covariate_encoder, + rt_encoder, + ) = task + + smiles = record.get(smiles_col) + if _is_missing_value(smiles): + return idx, None + + mol = _safe_metabolite(smiles) + if mol is None: + return idx, None + + try: + mol.create_molecular_structure_graph() + mol.compute_graph_attributes(node_encoder, bond_encoder) + except Exception: + return idx, None + + if group_id_col in record: + group_id = record.get(group_id_col) + if not _is_missing_value(group_id): + try: + mol.set_id(int(group_id)) + except Exception: + pass + + summary = record.get(summary_col) if summary_col in record else None + if summary is None: + summary = _build_summary_from_record(record, metadata_key_map) + + try: + mol.add_metadata(summary, covariate_encoder, rt_encoder) + except Exception: + return idx, None + + loss_weight = record.get(loss_weight_col) if loss_weight_col in record else None + if not _is_missing_value(loss_weight): + try: + mol.set_loss_weight(float(loss_weight)) + except Exception: + mol.set_loss_weight(1.0) + else: + mol.set_loss_weight(1.0) + + return idx, mol + + +def _resolve_tolerance(record: dict, ppm_col: str, ppm_default: float) -> float: + tol = ppm_default + if ppm_col in record: + try: + val = float(record[ppm_col]) + if not np.isnan(val): + tol = val + except Exception: + pass + return tol + + +def _match_peaks_task(task): + idx, metabolite, peaks, tol = task + if not isinstance(peaks, dict): + return idx, False + mz = peaks.get("mz") + intensity = peaks.get("intensity") + if mz is None or intensity is None or len(mz) == 0: + return idx, False + try: + metabolite.match_fragments_to_peaks(mz, intensity, tolerance=tol) + return idx, True + except Exception: + return idx, False + + def _resolve_device(device: str) -> str: if device == "auto": return "cuda:0" if torch.cuda.is_available() else "cpu" @@ -385,44 +497,31 @@ def main() -> None: # Build metabolites invalid_rows = [] - for idx, row in df.iterrows(): - smiles = row.get(args.smiles_col) - if smiles is None or (isinstance(smiles, float) and np.isnan(smiles)): - invalid_rows.append(idx) - continue - mol = _safe_metabolite(smiles) + metabolite_tasks = ( + ( + idx, + row.to_dict(), + args.smiles_col, + args.group_id_col, + args.summary_col, + args.loss_weight_col, + 
metadata_key_map, + node_encoder, + bond_encoder, + covariate_encoder, + rt_encoder, + ) + for idx, row in df.iterrows() + ) + metabolite_results = _parallel_map( + _prepare_metabolite_task, metabolite_tasks, args.num_workers + ) + for idx, mol in _progress_iterator( + metabolite_results, total=len(df), desc="Building graphs" + ): if mol is None: invalid_rows.append(idx) continue - mol.create_molecular_structure_graph() - mol.compute_graph_attributes(node_encoder, bond_encoder) - - if args.group_id_col in df.columns: - try: - mol.set_id(int(row[args.group_id_col])) - except Exception: - pass - - summary = None - if args.summary_col in df.columns: - summary = row.get(args.summary_col) - if summary is None: - summary = _build_summary_from_columns(row, metadata_key_map) - - try: - mol.add_metadata(summary, covariate_encoder, rt_encoder) - except Exception: - invalid_rows.append(idx) - continue - - if args.loss_weight_col in df.columns: - try: - mol.set_loss_weight(float(row[args.loss_weight_col])) - except Exception: - mol.set_loss_weight(1.0) - else: - mol.set_loss_weight(1.0) - df.at[idx, "Metabolite"] = mol if invalid_rows: @@ -443,27 +542,17 @@ def main() -> None: # Match peaks to fragments ppm_default = args.ppm if args.ppm is not None else DEFAULT_PPM match_invalid = [] - for idx, row in df.iterrows(): - peaks = row.get(args.peaks_col) - if not isinstance(peaks, dict): - match_invalid.append(idx) - continue - mz = peaks.get("mz") - intensity = peaks.get("intensity") - if mz is None or intensity is None or len(mz) == 0: - match_invalid.append(idx) - continue - tol = ppm_default - if args.ppm_col in df.columns: - try: - val = float(row[args.ppm_col]) - if not np.isnan(val): - tol = val - except Exception: - pass - try: - row["Metabolite"].match_fragments_to_peaks(mz, intensity, tolerance=tol) - except Exception: + match_tasks = ( + ( + idx, + row["Metabolite"], + row.get(args.peaks_col), + _resolve_tolerance(row, args.ppm_col, ppm_default), + ) + for idx, row in df.iterrows() + ) + for idx, matched in _parallel_map(_match_peaks_task, match_tasks, args.num_workers): + if not matched: match_invalid.append(idx) if match_invalid: @@ -613,6 +702,7 @@ def main() -> None: validation_mask_name=args.validation_mask_name, output_path=output_path, logger=print, + pin_memory=args.pin_memory, ) if args.history_out: _save_history(history, args.history_out) From 2b10b5dad1c7e82739b62da697b527cbd5e313ce Mon Sep 17 00:00:00 2001 From: Philipp Benner Date: Mon, 9 Mar 2026 10:16:06 +0100 Subject: [PATCH 09/15] Refactor training pipeline, fix MOL edge-count accumulation, and add MOL regression tests --- README.md | 28 ++ fiora/GNN/fabric_training.py | 480 +++++++++++++++++++++----- fiora/MOL/Metabolite.py | 33 +- fiora/MOL/mol_graph.py | 2 - fiora/MS/SimulationFramework.py | 7 +- fiora/cli/eval.py | 18 +- fiora/cli/train.py | 58 ++++ tests/data/mol_core_100_spectra.jsonl | 100 ++++++ tests/test_mol_core.py | 240 +++++++++++++ 9 files changed, 861 insertions(+), 105 deletions(-) create mode 100644 tests/data/mol_core_100_spectra.jsonl create mode 100644 tests/test_mol_core.py diff --git a/README.md b/README.md index 4604481..4d1b579 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,34 @@ fiora-train ... --history-out checkpoints/fiora_history.json `pin_memory` is enabled automatically on CUDA; you can override with `--pin-memory` or `--no-pin-memory`. `--num-workers` is used for both DataLoader workers and parallel preprocessing (thread-based metabolite graph/peak matching setup) in the training CLI. 
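+
+As a quick illustration (the input and checkpoint paths below are placeholders),
+preprocessing with eight worker threads and pinned host memory explicitly
+disabled would look like:
+
+```bash
+fiora-train \
+    -i resources/data/msnlib/library.csv \
+    -o checkpoints/fiora.pt \
+    --device cuda:0 \
+    --num-workers 8 \
+    --no-pin-memory
+```
+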
+For stronger cosine performance, a common setup is: + +```bash +# Stage 1 +fiora-train \ + -i resources/data/msnlib/library.csv \ + -o checkpoints/fiora_stage1.pt \ + --device cuda:0 \ + --instruments HCD \ + --precursor-modes "[M+H]+,[M-H]-,[M]+,[M]-" \ + --hidden-dimension 384 \ + --residual-connections \ + --no-layer-stacking + +# Stage 2 (optional continuation) +fiora-train \ + -i resources/data/msnlib/library.csv \ + -o checkpoints/fiora.pt \ + --resume checkpoints/fiora_stage1.pt \ + --device cuda:0 \ + --instruments HCD \ + --precursor-modes "[M+H]+,[M-H]-,[M]+,[M]-" \ + --loss weighted_mse \ + --y-label compiled_probsSQRT \ + --learning-rate 5e-5 \ + --epochs 30 +``` + ### Model Evaluation CLI You can evaluate a trained checkpoint on validation/test splits with: diff --git a/fiora/GNN/fabric_training.py b/fiora/GNN/fabric_training.py index dc88b24..8278fe8 100644 --- a/fiora/GNN/fabric_training.py +++ b/fiora/GNN/fabric_training.py @@ -1,10 +1,13 @@ import random +import warnings +from dataclasses import dataclass from typing import Callable import numpy as np import torch import torch_geometric.loader as geom_loader from lightning_fabric import Fabric +from lightning_fabric.utilities.warnings import PossibleUserWarning from torchmetrics import MeanSquaredError from fiora.GNN.Losses import WeightedMAELoss, WeightedMSELoss @@ -23,12 +26,16 @@ def build_loss_kwargs( with_weights: bool, mask: torch.Tensor | None = None, include_segment_ptr: bool = True, + weight_tensor_override: torch.Tensor | None = None, ): kwargs = {} if with_weights: - kwargs["weight"] = ( - batch["weight_tensor"] if mask is None else batch["weight_tensor"][mask] + weight_tensor = ( + weight_tensor_override + if weight_tensor_override is not None + else batch["weight_tensor"] ) + kwargs["weight"] = weight_tensor if mask is None else weight_tensor[mask] if include_segment_ptr and getattr(loss_fn, "requires_segment_ptr", False): kwargs["segment_ptr"] = y_pred.get("segment_ptr") return kwargs @@ -145,6 +152,29 @@ def unwrap_model(model): return model.module if hasattr(model, "module") else model +def apply_precursor_loss_weight( + weight_tensor: torch.Tensor, + segment_ptr: torch.Tensor | None, + precursor_loss_weight: float, +) -> torch.Tensor: + if precursor_loss_weight == 1.0: + return weight_tensor + if segment_ptr is None or segment_ptr.numel() < 2: + return weight_tensor + + weighted = weight_tensor.clone() + starts = segment_ptr[:-1] + ends = segment_ptr[1:] + lengths = ends - starts + valid = lengths >= 2 + if torch.any(valid): + right_idx = ends[valid] - 1 + left_idx = ends[valid] - 2 + weighted[right_idx] = weighted[right_idx] * precursor_loss_weight + weighted[left_idx] = weighted[left_idx] * precursor_loss_weight + return weighted + + def move_batch_to_device(batch, device, non_blocking: bool): try: return batch.to(device, non_blocking=non_blocking) @@ -170,6 +200,7 @@ def run_epoch( show_progress: bool = False, progress_desc: str = "", non_blocking_transfer: bool = False, + precursor_loss_weight: float = 1.0, ): is_training = optimizer is not None if is_training: @@ -190,6 +221,16 @@ def run_epoch( ) with torch.set_grad_enabled(is_training): y_pred = model(batch, with_RT=with_rt, with_CCS=with_ccs) + use_weight_vector = with_weights or getattr( + loss_fn, "requires_segment_ptr", False + ) + weight_tensor = None + if use_weight_vector: + weight_tensor = apply_precursor_loss_weight( + batch["weight_tensor"], + y_pred.get("segment_ptr"), + precursor_loss_weight, + ) if use_validation_mask: mask = 
batch[mask_name] @@ -198,9 +239,10 @@ def run_epoch( batch=batch, y_pred=y_pred, loss_fn=loss_fn, - with_weights=with_weights, + with_weights=use_weight_vector, mask=mask, include_segment_ptr=False, + weight_tensor_override=weight_tensor, ) loss = loss_fn( y_pred["fragment_probs"][mask], @@ -246,8 +288,9 @@ def run_epoch( batch=batch, y_pred=y_pred, loss_fn=loss_fn, - with_weights=with_weights, + with_weights=use_weight_vector, include_segment_ptr=True, + weight_tensor_override=weight_tensor, ) loss = loss_fn(y_pred["fragment_probs"], batch[y_tag], **kwargs) if not rt_metric: @@ -294,6 +337,250 @@ def run_epoch( return avg_loss, metric_label, metric_value +@dataclass +class EpochResult: + loss: float + metric_label: str + metric_value: float + + +@dataclass +class TrainingState: + best_metric: float + best_epoch: int + history: dict + + +def _init_history() -> dict: + return { + "epoch": [], + "train_error": [], + "sqrt_train_error": [], + "val_error": [], + "sqrt_val_error": [], + "lr": [], + } + + +def _record_history( + history: dict, + epoch: int, + lr: float, + train_result: EpochResult | None, + val_result: EpochResult | None, +) -> None: + history["epoch"].append(epoch) + history["train_error"].append( + train_result.metric_value if train_result is not None else float("nan") + ) + history["sqrt_train_error"].append( + train_result.metric_value if train_result is not None else float("nan") + ) + history["val_error"].append( + val_result.metric_value if val_result is not None else float("nan") + ) + history["sqrt_val_error"].append( + val_result.metric_value if val_result is not None else float("nan") + ) + history["lr"].append(lr) + + +def _run_train_epoch( + *, + fabric: Fabric, + model: torch.nn.Module, + dataloader, + loss_fn, + metric, + metric_name: str, + y_label: str, + with_weights: bool, + with_rt: bool, + with_ccs: bool, + rt_metric: bool, + optimizer, + show_progress: bool, + progress_desc: str, + non_blocking_transfer: bool, + precursor_loss_weight: float, +) -> EpochResult: + loss, label, value = run_epoch( + fabric=fabric, + model=model, + dataloader=dataloader, + loss_fn=loss_fn, + metric=metric, + metric_name=metric_name, + y_tag=y_label, + with_weights=with_weights, + with_rt=with_rt, + with_ccs=with_ccs, + rt_metric=rt_metric, + optimizer=optimizer, + show_progress=show_progress, + progress_desc=progress_desc, + non_blocking_transfer=non_blocking_transfer, + precursor_loss_weight=precursor_loss_weight, + ) + return EpochResult(loss=loss, metric_label=label, metric_value=value) + + +def _run_val_epoch( + *, + fabric: Fabric, + model: torch.nn.Module, + dataloader, + loss_fn, + metric, + metric_name: str, + y_label: str, + with_weights: bool, + with_rt: bool, + with_ccs: bool, + rt_metric: bool, + use_validation_mask: bool, + validation_mask_name: str, + show_progress: bool, + progress_desc: str, + non_blocking_transfer: bool, + precursor_loss_weight: float, +) -> EpochResult: + loss, label, value = run_epoch( + fabric=fabric, + model=model, + dataloader=dataloader, + loss_fn=loss_fn, + metric=metric, + metric_name=metric_name, + y_tag=y_label, + with_weights=with_weights, + with_rt=with_rt, + with_ccs=with_ccs, + rt_metric=rt_metric, + use_validation_mask=use_validation_mask, + mask_name=validation_mask_name, + show_progress=show_progress, + progress_desc=progress_desc, + non_blocking_transfer=non_blocking_transfer, + precursor_loss_weight=precursor_loss_weight, + ) + return EpochResult(loss=loss, metric_label=label, metric_value=value) + + +def _monitor_metric( + *, 
+ has_validation: bool, + is_val_cycle: bool, + train_result: EpochResult, + val_result: EpochResult | None, +) -> float | None: + if is_val_cycle and val_result is not None: + return val_result.metric_value + if not has_validation: + return train_result.metric_value + return None + + +def _step_scheduler( + *, + scheduler, + optimizer, + monitor_metric: float | None, + fabric: Fabric, + logger: Callable[[str], None] | None, +) -> None: + if scheduler is None: + return + prev_lr = optimizer.param_groups[0]["lr"] + if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + if monitor_metric is not None and not np.isnan(monitor_metric): + scheduler.step(monitor_metric) + else: + scheduler.step() + curr_lr = optimizer.param_groups[0]["lr"] + if logger is not None and fabric.is_global_zero and curr_lr < prev_lr: + logger(f"\t >> Learning rate reduced from {prev_lr:1.0e} to {curr_lr:1.0e}") + + +def _maybe_update_best( + *, + state: TrainingState, + monitor_metric: float | None, + epoch: int, + model, + output_path: str | None, + fabric: Fabric, + logger: Callable[[str], None] | None, + baseline: bool = False, +) -> None: + if monitor_metric is None or np.isnan(monitor_metric): + return + if monitor_metric >= state.best_metric: + return + state.best_metric = monitor_metric + state.best_epoch = epoch + if output_path is not None and fabric.is_global_zero: + unwrap_model(model).save(output_path) + if logger is not None: + if baseline: + logger("\t >> Set baseline checkpoint to epoch 0") + else: + logger(f"\t >> Set new checkpoint to epoch {epoch}") + + +def _log_epoch( + *, + epoch: int, + epochs: int, + train_result: EpochResult | None, + val_result: EpochResult | None, + fabric: Fabric, + logger: Callable[[str], None] | None, +) -> None: + if logger is None or not fabric.is_global_zero: + return + + train_label = ( + train_result.metric_label + if train_result is not None + else (val_result.metric_label if val_result is not None else "metric") + ) + val_label = ( + val_result.metric_label + if val_result is not None + else (train_result.metric_label if train_result is not None else "metric") + ) + + train_loss_str = ( + f"{train_result.loss:.4f}" + if train_result is not None and not np.isnan(train_result.loss) + else "n/a" + ) + val_loss_str = ( + f"{val_result.loss:.4f}" + if val_result is not None and not np.isnan(val_result.loss) + else "n/a" + ) + train_metric_str = ( + f"{train_result.metric_value:.4f}" + if train_result is not None and not np.isnan(train_result.metric_value) + else "n/a" + ) + val_metric_str = ( + f"{val_result.metric_value:.4f}" + if val_result is not None and not np.isnan(val_result.metric_value) + else "n/a" + ) + + logger( + f"Epoch {epoch}/{epochs} - " + f"loss: {train_loss_str} - " + f"val_loss: {val_loss_str} - " + f"train_{train_label}: {train_metric_str} - " + f"val_{val_label}: {val_metric_str}" + ) + + def train_fabric_loop( *, model, @@ -324,9 +611,17 @@ def train_fabric_loop( launch_fabric: bool = True, logger: Callable[[str], None] | None = print, pin_memory: bool | None = None, + precursor_loss_weight: float = 1.0, ): has_validation = len(val_data) > 0 accelerator, devices = resolve_fabric_runtime(device) + warnings.filterwarnings( + "ignore", + message=r"The `srun` command is available on your system but is not used\..*", + category=PossibleUserWarning, + ) + if accelerator == "cuda": + torch.set_float32_matmul_precision("high") if pin_memory is None: pin_memory = accelerator == "cuda" use_non_blocking_transfer = bool(pin_memory and 
accelerator == "cuda") @@ -383,26 +678,66 @@ def train_fabric_loop( show_train_progress = len(train_data) > progress_threshold show_val_progress = has_validation and (len(val_data) > progress_threshold) - best_metric = float("inf") - best_epoch = -1 - history = { - "epoch": [], - "train_error": [], - "sqrt_train_error": [], - "val_error": [], - "sqrt_val_error": [], - "lr": [], - } + state = TrainingState( + best_metric=float("inf"), best_epoch=-1, history=_init_history() + ) + + if has_validation: + baseline_result = _run_val_epoch( + fabric=fabric, + model=model, + dataloader=val_loader, + loss_fn=loss_fn, + metric=val_metric, + metric_name=metric_name, + y_label=y_label, + with_weights=with_weights, + with_rt=with_rt, + with_ccs=with_ccs, + rt_metric=rt_metric, + use_validation_mask=use_validation_mask, + validation_mask_name=validation_mask_name, + show_progress=show_val_progress, + progress_desc=f"Val 0/{epochs}", + non_blocking_transfer=use_non_blocking_transfer, + precursor_loss_weight=precursor_loss_weight, + ) + if fabric.is_global_zero: + _record_history( + state.history, + epoch=0, + lr=optimizer.param_groups[0]["lr"], + train_result=None, + val_result=baseline_result, + ) + _maybe_update_best( + state=state, + monitor_metric=baseline_result.metric_value, + epoch=0, + model=model, + output_path=output_path, + fabric=fabric, + logger=logger, + baseline=True, + ) + _log_epoch( + epoch=0, + epochs=epochs, + train_result=None, + val_result=baseline_result, + fabric=fabric, + logger=logger, + ) for epoch in range(1, epochs + 1): - train_loss, train_metric_label, train_metric_value = run_epoch( + train_result = _run_train_epoch( fabric=fabric, model=model, dataloader=train_loader, loss_fn=loss_fn, metric=train_metric, metric_name=metric_name, - y_tag=y_label, + y_label=y_label, with_weights=with_weights, with_rt=with_rt, with_ccs=with_ccs, @@ -411,96 +746,83 @@ def train_fabric_loop( show_progress=show_train_progress, progress_desc=f"Train {epoch}/{epochs}", non_blocking_transfer=use_non_blocking_transfer, + precursor_loss_weight=precursor_loss_weight, ) is_val_cycle = has_validation and (epoch % val_every == 0) + val_result = None if is_val_cycle: - val_loss, val_metric_label, val_metric_value = run_epoch( + val_result = _run_val_epoch( fabric=fabric, model=model, dataloader=val_loader, loss_fn=loss_fn, metric=val_metric, metric_name=metric_name, - y_tag=y_label, + y_label=y_label, with_weights=with_weights, with_rt=with_rt, with_ccs=with_ccs, rt_metric=rt_metric, use_validation_mask=use_validation_mask, - mask_name=validation_mask_name, + validation_mask_name=validation_mask_name, show_progress=show_val_progress, progress_desc=f"Val {epoch}/{epochs}", non_blocking_transfer=use_non_blocking_transfer, + precursor_loss_weight=precursor_loss_weight, ) - else: - val_loss = float("nan") - val_metric_label = train_metric_label - val_metric_value = float("nan") - - monitor_metric = None - if is_val_cycle: - monitor_metric = val_metric_value - elif not has_validation: - monitor_metric = train_metric_value - - if scheduler is not None: - prev_lr = optimizer.param_groups[0]["lr"] - if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - if monitor_metric is not None and not np.isnan(monitor_metric): - scheduler.step(monitor_metric) - else: - scheduler.step() - curr_lr = optimizer.param_groups[0]["lr"] - if logger is not None and fabric.is_global_zero and curr_lr < prev_lr: - logger( - f"\t >> Learning rate reduced from {prev_lr:1.0e} to {curr_lr:1.0e}" - ) - if monitor_metric 
is not None and not np.isnan(monitor_metric): - if monitor_metric < best_metric: - best_metric = monitor_metric - best_epoch = epoch - if output_path is not None and fabric.is_global_zero: - unwrap_model(model).save(output_path) - if logger is not None: - logger(f"\t >> Set new checkpoint to epoch {epoch}") + monitor_metric = _monitor_metric( + has_validation=has_validation, + is_val_cycle=is_val_cycle, + train_result=train_result, + val_result=val_result, + ) + _step_scheduler( + scheduler=scheduler, + optimizer=optimizer, + monitor_metric=monitor_metric, + fabric=fabric, + logger=logger, + ) + _maybe_update_best( + state=state, + monitor_metric=monitor_metric, + epoch=epoch, + model=model, + output_path=output_path, + fabric=fabric, + logger=logger, + baseline=False, + ) if (is_val_cycle or not has_validation) and fabric.is_global_zero: - history["epoch"].append(epoch) - history["train_error"].append(train_metric_value) - history["sqrt_train_error"].append(train_metric_value) - history["val_error"].append( - val_metric_value if is_val_cycle else float("nan") - ) - history["sqrt_val_error"].append( - val_metric_value if is_val_cycle else float("nan") - ) - history["lr"].append(optimizer.param_groups[0]["lr"]) - - if logger is not None and fabric.is_global_zero: - val_loss_str = f"{val_loss:.4f}" if not np.isnan(val_loss) else "n/a" - val_metric_str = ( - f"{val_metric_value:.4f}" if not np.isnan(val_metric_value) else "n/a" - ) - logger( - f"Epoch {epoch}/{epochs} - " - f"loss: {train_loss:.4f} - " - f"val_loss: {val_loss_str} - " - f"train_{train_metric_label}: {train_metric_value:.4f} - " - f"val_{val_metric_label}: {val_metric_str}" + _record_history( + state.history, + epoch=epoch, + lr=optimizer.param_groups[0]["lr"], + train_result=train_result, + val_result=val_result if is_val_cycle else None, ) + _log_epoch( + epoch=epoch, + epochs=epochs, + train_result=train_result, + val_result=val_result, + fabric=fabric, + logger=logger, + ) - if best_epoch < 0: - best_epoch = epochs - best_metric = float("nan") + if state.best_epoch < 0: + state.best_epoch = epochs + state.best_metric = float("nan") if output_path is not None and fabric.is_global_zero: unwrap_model(model).save(output_path) checkpoints = { - "epoch": best_epoch, - "val_loss": best_metric, - "sqrt_val_loss": best_metric, + "epoch": state.best_epoch, + "val_loss": state.best_metric, + "sqrt_val_loss": state.best_metric, "file": output_path, } - return checkpoints, history + return checkpoints, state.history diff --git a/fiora/MOL/Metabolite.py b/fiora/MOL/Metabolite.py index de948d3..e05f5ab 100644 --- a/fiora/MOL/Metabolite.py +++ b/fiora/MOL/Metabolite.py @@ -18,7 +18,6 @@ from fiora.MOL.constants import ( DEFAULT_PPM, - DEFAULT_MODES, DEFAULT_MODE_MAP, ADDUCT_WEIGHTS, ORDERED_ELEMENT_LIST_WITH_HYDROGEN, @@ -27,12 +26,7 @@ from fiora.MOL.mol_graph import ( mol_to_graph, get_adjacency_matrix, - get_degree_matrix, get_edges, - get_identity_matrix, - draw_graph, - compute_edge_related_helper_matrices, - get_helper_matrices_from_edges, ) from fiora.MOL.FragmentationTree import FragmentationTree from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder @@ -453,6 +447,13 @@ def extract_subgraph_features_from_edges(self) -> None: # Store the element composition for the edge self.subgraph_elem_comp[i, :] = edge_elem_comp + @staticmethod + def _edge_count_cols(mode_map, mode_count, ion_mode, break_side): + base_col = mode_map[ion_mode] + if break_side == "left": + return base_col, base_col + mode_count + return base_col + 
mode_count, base_col
+
     def match_fragments_to_peaks(
         self,
         mz_fragments,
@@ -521,13 +522,15 @@ def match_fragments_to_peaks(
                 torch.tensor(0.0),
             )
 
+        mode_count = len(mode_map)
         self.edge_count_matrix = torch.zeros(
-            size=(edge_break_labels.shape[0], 2 * len(mode_map)), dtype=torch.float32
+            size=(edge_break_labels.shape[0], 2 * mode_count), dtype=torch.float32
         )
+        get_edge_count_cols = self._edge_count_cols
 
         # Determining edge break probabilities from peak intensities. Multiple edges for the same fragment -> divide by number of edges. Multiple fragments from edge -> add intensities.
         for edge, values in self.edge_intensities:
-            if edge == None:  # precursor
+            if edge is None:  # precursor
                 self.precursor_count += values["intensity"]
                 continue
             edge_index = (
@@ -549,17 +552,11 @@ def match_fragments_to_peaks(
                 .nonzero()
                 .squeeze()
             )
-
-            col = (
-                mode_map[values["ion_mode"]]
-                if values["break_side"] == "left"
-                else mode_map[values["ion_mode"]] + len(mode_map)
+            forward_col, backward_col = get_edge_count_cols(
+                mode_map, mode_count, values["ion_mode"], values["break_side"]
             )
-            self.edge_count_matrix[forward_idx, col] = values["intensity"]
-            col = (col + len(mode_map)) % (
-                2 * len(mode_map)
-            )  # to the other side of the break
-            self.edge_count_matrix[backward_idx, col] = values["intensity"]
+            self.edge_count_matrix[forward_idx, forward_col] += values["intensity"]
+            self.edge_count_matrix[backward_idx, backward_col] += values["intensity"]
 
         # "bond_features_one_hot",
         # Compile probability vectors
diff --git a/fiora/MOL/mol_graph.py b/fiora/MOL/mol_graph.py
index cd61160..8b1727e 100644
--- a/fiora/MOL/mol_graph.py
+++ b/fiora/MOL/mol_graph.py
@@ -6,7 +6,6 @@
 node_color_map = {"C": "gray", "O": "red", "N": "blue"}
 
-
 edge_color_map = {"SINGLE": "black", "DOUBLE": "black", "AROMATIC": "blue"}
 edge_width_map = {"SINGLE": 1.5, "DOUBLE": 3, "AROMATIC": 3}
 
@@ -118,7 +117,6 @@ def compute_edge_related_helper_matrices(A, deg):
 def get_helper_matrices_from_edges(edges, A):
     AL = torch.zeros(len(edges), A.shape[0])
     AR = torch.zeros(AL.shape)
-    edge_idx = []
 
     for i, (u, v) in enumerate(edges):
         AL[i, u] = 1.0
diff --git a/fiora/MS/SimulationFramework.py b/fiora/MS/SimulationFramework.py
index 59b83d2..164d573 100644
--- a/fiora/MS/SimulationFramework.py
+++ b/fiora/MS/SimulationFramework.py
@@ -173,7 +173,7 @@ def simulate_spectrum(
         if transform_prob == "square":
             max_prob = max(sim_peaks["intensity"]) ** 2
             for i in range(len(sim_peaks["intensity"])):
-                sim_peaks["intensity"][i] == sim_peaks["intensity"][i] ** 2 / max_prob
+                sim_peaks["intensity"][i] = sim_peaks["intensity"][i] ** 2 / max_prob
 
         combined = sorted(
             zip(sim_peaks["mz"], sim_peaks["intensity"], sim_peaks["annotation"]),
@@ -207,11 +207,12 @@ def simulate_and_score(
             stats["CCS_pred"] = prediction["ccs"].squeeze().tolist()
         setattr(metabolite, base_attr_name + "_pred", prediction["fragment_probs"])
 
+        training_label = model.model_params.get("training_label")
         transform_prob = (
             "square"
             if (
-                "training_label" in model.model_params
-                and model.model_params["training_label"] == "compiled_probsSQRT"
+                training_label == "compiled_probsSQRT"
+                or (training_label is None and base_attr_name == "compiled_probsSQRT")
             )
             else "None"
         )
diff --git a/fiora/cli/eval.py b/fiora/cli/eval.py
index 548fe13..94a1124 100644
--- a/fiora/cli/eval.py
+++ b/fiora/cli/eval.py
@@ -70,8 +70,8 @@ def parse_args() -> argparse.Namespace:
     )
     parser.add_argument(
         "--y-label",
-        default="compiled_probsALL",
-        help="Prediction target label used during training.",
+        default=None,
+        
help="Prediction target label used during training (default: from model params, fallback compiled_probsALL).", ) parser.add_argument( "--min-prob", @@ -316,6 +316,18 @@ def main() -> None: model = _load_model(args.model, dev) model.eval() + y_label = args.y_label or model.model_params.get( + "training_label", "compiled_probsALL" + ) + if args.y_label is None: + print(f"Using y-label from model params: {y_label}") + elif model.model_params.get("training_label") and y_label != model.model_params.get( + "training_label" + ): + print( + "Warning: --y-label does not match model training label " + f"({y_label} vs {model.model_params.get('training_label')})." + ) # Standardize user-configurable column names for downstream code. if args.summary_col != "summary" and args.summary_col in df.columns: @@ -359,7 +371,7 @@ def main() -> None: part = fiora.simulate_all( part, model, - base_attr_name=args.y_label, + base_attr_name=y_label, groundtruth=use_groundtruth, min_intensity=args.min_prob, progress=args.progress, diff --git a/fiora/cli/train.py b/fiora/cli/train.py index e67941c..82573cc 100644 --- a/fiora/cli/train.py +++ b/fiora/cli/train.py @@ -71,11 +71,47 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--learning-rate", type=float, default=2e-4) parser.add_argument("--weight-decay", type=float, default=1e-5) + parser.add_argument( + "--hidden-dimension", + type=int, + default=None, + help="Override model hidden dimension (default from model params).", + ) + parser.add_argument( + "--embedding-dimension", + type=int, + default=None, + help="Override embedding dimension (default from model params).", + ) + parser.add_argument( + "--dense-dim", + type=int, + default=None, + help="Override dense layer hidden dimension (None keeps current setting).", + ) + parser.add_argument( + "--residual-connections", + action=argparse.BooleanOptionalAction, + default=None, + help="Override residual connections setting.", + ) + parser.add_argument( + "--layer-stacking", + action=argparse.BooleanOptionalAction, + default=None, + help="Override layer stacking setting.", + ) parser.add_argument( "--loss", choices=["graphwise_kl", "weighted_mse", "weighted_mae", "mse"], default="graphwise_kl", ) + parser.add_argument( + "--precursor-loss-weight", + type=float, + default=1.0, + help="Multiplier for precursor positions in fragment loss (1.0 keeps original weighting).", + ) parser.add_argument( "--y-label", default="compiled_probsALL", @@ -643,6 +679,26 @@ def main() -> None: "ccs_supported": args.with_ccs, } ) + if args.hidden_dimension is not None: + model_params["hidden_dimension"] = int(args.hidden_dimension) + if args.embedding_dimension is not None: + model_params["embedding_dimension"] = int(args.embedding_dimension) + if args.dense_dim is not None: + model_params["dense_dim"] = int(args.dense_dim) + if args.residual_connections is not None: + model_params["residual_connections"] = bool(args.residual_connections) + if args.layer_stacking is not None: + model_params["layer_stacking"] = bool(args.layer_stacking) + if model_params.get("residual_connections", False): + if ( + model_params.get("hidden_dimension") + != model_params.get("embedding_dimension") + and args.embedding_dimension is None + ): + model_params["embedding_dimension"] = model_params["hidden_dimension"] + if args.dense_dim is None and "dense_dim" not in base_params: + # Avoid shape-mismatch in dense residual blocks when using default params. 
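+            # Residual additions require matching input/output widths, so a
+            # fixed dense hidden size carried over from the defaults may not fit.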
+ model_params["dense_dim"] = None # Initialize or resume model if args.resume: @@ -661,6 +717,7 @@ def main() -> None: raise RuntimeError( "Model does not include RT/CCS heads but --with-rt/--with-ccs was set." ) + model.model_params["training_label"] = args.y_label loss_fn, metric_dict = _choose_loss(args.loss) @@ -703,6 +760,7 @@ def main() -> None: output_path=output_path, logger=print, pin_memory=args.pin_memory, + precursor_loss_weight=args.precursor_loss_weight, ) if args.history_out: _save_history(history, args.history_out) diff --git a/tests/data/mol_core_100_spectra.jsonl b/tests/data/mol_core_100_spectra.jsonl new file mode 100644 index 0000000..c46fe3e --- /dev/null +++ b/tests/data/mol_core_100_spectra.jsonl @@ -0,0 +1,100 @@ +{"SMILES": "CCOc1nc2c(cc(C(=O)NCC3CCCO3)cc2)[nH]1", "group_id": 0, "peaks": {"mz": [132.034439, 133.040298, 188.047424, 244.105377, 259.095642, 260.03833, 260.103394, 260.130463, 260.15387, 260.167358, 260.185913, 288.134888, 288.165802], "intensity": [0.251, 1.832, 0.276, 0.342, 21.103, 0.442, 100.0, 1.727, 1.109, 0.829, 0.265, 31.431, 0.411], "annotation": []}, "summary": {"name": "Z1915955354", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 288.13537, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 289.142641468}} +{"SMILES": "CCOc1nc2c(cc(C(=O)NCC3CCCO3)cc2)[nH]1", "group_id": 0, "peaks": {"mz": [72.032089, 72.036255, 244.105789, 259.095703, 260.103516, 260.130524, 260.16748, 288.134766, 288.165436, 288.209839, 288.230713, 288.253082], "intensity": [0.186, 0.228, 0.272, 3.724, 32.684, 0.537, 0.274, 100.0, 1.287, 0.755, 0.225, 0.18], "annotation": []}, "summary": {"name": "Z1915955354", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 288.13537, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 289.142641468}} +{"SMILES": "CN=S(=O)(c1ccc(NC(=O)N2CCC(F)(F)C2)cc1)N(C)C", "group_id": 1, "peaks": {"mz": [41.998081, 66.034416, 173.540161, 204.070282, 212.085541, 224.076111, 233.096283, 238.064606, 252.037964, 253.102699, 253.129227, 272.043274, 280.983032, 282.070099, 301.069824, 302.077148, 325.022736, 325.113547, 325.184143, 325.203094, 345.019745, 345.119588, 345.155914, 345.19635, 345.217987], "intensity": [0.278, 2.387, 0.746, 1.131, 0.998, 2.86, 8.033, 0.676, 0.458, 17.766, 0.293, 0.902, 0.273, 0.322, 0.26, 0.527, 0.392, 71.423, 0.719, 0.567, 0.501, 100.0, 1.852, 0.777, 0.713], "annotation": []}, "summary": {"name": "3,3-difluoro-N-[4-(trimethyl-S-aminosulfonimidoyl)phenyl]pyrrolidine-1-carboxamide", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 345.12023, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 346.12750332}} +{"SMILES": "CN=S(=O)(c1ccc(NC(=O)N2CCC(F)(F)C2)cc1)N(C)C", "group_id": 1, "peaks": {"mz": [41.998081, 66.034416, 204.070282, 212.085541, 224.076111, 233.096252, 238.064606, 252.037964, 253.102676, 253.129227, 272.043274, 282.070099, 301.069824, 302.077148, 325.022736, 325.113525, 325.184143, 325.203094, 345.019745, 345.119568, 345.155914, 345.19635, 345.217987], "intensity": [0.278, 2.387, 1.131, 0.998, 2.86, 8.033, 0.676, 0.458, 17.766, 0.293, 0.902, 0.322, 0.26, 0.527, 0.392, 71.423, 0.719, 0.567, 0.501, 100.0, 1.852, 0.777, 0.713], "annotation": []}, "summary": {"name": "3,3-difluoro-N-[4-(trimethyl-S-aminosulfonimidoyl)phenyl]pyrrolidine-1-carboxamide", "collision_energy": 20.0, "instrument": "HCD", 
"ionization": "ESI", "precursor_mz": 345.12023, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 346.12750332}} +{"SMILES": "CCn1nc(C)cc1C(=O)Nc1nc2c(cc(OC)cc2)s1", "group_id": 3, "peaks": {"mz": [57.97522, 120.998901, 148.99382, 162.997406, 190.992188, 242.039719, 256.054321, 285.045166, 287.061157, 299.988129, 300.004639, 300.068817, 300.132874, 300.148895, 300.171448, 315.092377], "intensity": [0.28, 0.986, 2.738, 0.703, 0.232, 0.288, 0.36, 0.752, 0.372, 0.72, 0.416, 100.0, 0.58, 0.717, 0.227, 18.94], "annotation": []}, "summary": {"name": "1-ethyl-N-(6-methoxy-1,3-benzothiazol-2-yl)-3-methyl-1H-pyrazole-5-carboxamide", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 315.09212, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 316.09939675199996}} +{"SMILES": "CCn1nc(C)cc1C(=O)Nc1nc2c(cc(OC)cc2)s1", "group_id": 3, "peaks": {"mz": [112.985268, 300.068573, 315.092041, 315.158936, 315.178192], "intensity": [0.397, 40.117, 100.0, 0.592, 0.655], "annotation": []}, "summary": {"name": "1-ethyl-N-(6-methoxy-1,3-benzothiazol-2-yl)-3-methyl-1H-pyrazole-5-carboxamide", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 315.09212, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 316.09939675199996}} +{"SMILES": "CC1CCCC1NC(=O)N1CCC(Nc2nn3ccnc3cc2)C1", "group_id": 4, "peaks": {"mz": [81.045647, 91.02993, 118.041, 120.056488, 133.051849, 147.067596, 159.067749, 173.544876, 200.094101, 202.109894], "intensity": [7.955, 10.833, 100.0, 3.652, 72.545, 5.198, 14.144, 4.51, 16.18, 59.219], "annotation": []}, "summary": {"name": "Z1760914806", "collision_energy": 60.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 327.19388, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [60.0], "molecular_weight": 328.201159388}} +{"SMILES": "CC1CCCC1NC(=O)N1CCC(Nc2nn3ccnc3cc2)C1", "group_id": 4, "peaks": {"mz": [173.541, 202.10997, 242.986633, 327.193237], "intensity": [2.697, 100.0, 2.88, 8.084], "annotation": []}, "summary": {"name": "Z1760914806", "collision_energy": 15.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 327.19388, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [15.0], "molecular_weight": 328.201159388}} +{"SMILES": "CC1CCCC1NC(=O)N1CCC(Nc2nn3ccnc3cc2)C1", "group_id": 4, "peaks": {"mz": [202.109955, 242.986862, 327.190643], "intensity": [100.0, 3.384, 1.943], "annotation": []}, "summary": {"name": "Z1760914806", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 327.19388, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 328.201159388}} +{"SMILES": "O=C(Cc1c[nH]c2c1nccc2)N[C@@H]1CC(Cn2cccn2)C[C@H]1O", "group_id": 5, "peaks": {"mz": [41.998059, 42.000187, 67.029595, 67.034111, 117.045212, 129.045273, 130.05278, 131.060866, 131.072144, 131.078522, 138.055573, 147.055677, 157.040008, 173.058548, 173.541443, 174.066681, 268.108307, 270.12439, 320.151093, 336.145172, 338.16179], "intensity": [14.131, 1.328, 28.186, 1.641, 2.013, 7.721, 1.307, 100.0, 2.08, 2.099, 1.173, 42.274, 50.436, 1.7, 2.763, 5.795, 1.731, 2.436, 2.371, 1.948, 28.096], "annotation": []}, "summary": {"name": "Z3687012989", "collision_energy": 45.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 338.16225, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [45.0], "molecular_weight": 
339.16952491200004}} +{"SMILES": "O=C(Cc1c[nH]c2c1nccc2)N[C@@H]1CC(Cn2cccn2)C[C@H]1O", "group_id": 5, "peaks": {"mz": [41.998055, 42.000187, 67.029602, 67.034111, 117.045212, 129.045273, 130.05278, 131.060867, 131.072144, 131.078522, 138.055573, 147.055695, 157.040009, 173.058548, 173.541443, 174.066681, 268.108307, 270.12439, 320.151093, 336.145172, 338.161804], "intensity": [14.131, 1.328, 28.186, 1.641, 2.013, 7.721, 1.307, 100.0, 2.08, 2.099, 1.173, 42.274, 50.436, 1.7, 2.763, 5.795, 1.731, 2.436, 2.371, 1.948, 28.096], "annotation": []}, "summary": {"name": "Z3687012989", "collision_energy": 45.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 338.16225, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [45.0], "molecular_weight": 339.16952491200004}} +{"SMILES": "CC(O)(CCc1ccccc1)C(=O)NCc1nc(-c2ccccn2)n[nH]1", "group_id": 6, "peaks": {"mz": [145.051376, 147.081146, 174.078003, 175.06218, 192.102478, 202.073105, 350.162231], "intensity": [3.229, 7.627, 100.0, 5.193, 17.063, 18.329, 31.095], "annotation": []}, "summary": {"name": "Z2298723851", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 350.16225, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 351.1695249120001}} +{"SMILES": "CC(O)(CCc1ccccc1)C(=O)NCc1nc(-c2ccccn2)n[nH]1", "group_id": 6, "peaks": {"mz": [41.014083, 41.998199, 131.049942, 145.051654, 147.081269, 158.059532, 174.07826, 190.048897], "intensity": [34.122, 4.91, 16.005, 9.99, 22.597, 46.794, 100.0, 8.694], "annotation": []}, "summary": {"name": "Z2298723851", "collision_energy": 60.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 350.16225, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [60.0], "molecular_weight": 351.1695249120001}} +{"SMILES": "CC(O)(CCc1ccccc1)C(=O)NCc1nc(-c2ccccn2)n[nH]1", "group_id": 6, "peaks": {"mz": [173.542816, 173.545273, 174.077977, 175.062576, 192.102509, 202.072827, 350.162354], "intensity": [3.862, 2.434, 21.126, 1.839, 4.319, 4.205, 100.0], "annotation": []}, "summary": {"name": "Z2298723851", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 350.16225, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 351.1695249120001}} +{"SMILES": "CC(O)(CCc1ccccc1)C(=O)NCc1nc(-c2ccccn2)n[nH]1", "group_id": 6, "peaks": {"mz": [41.01403, 131.049576, 145.051102, 147.080887, 158.059158, 173.541885, 174.077805, 175.061981, 192.102142, 202.072449], "intensity": [3.979, 3.076, 7.149, 22.247, 5.92, 5.684, 100.0, 3.467, 5.608, 4.376], "annotation": []}, "summary": {"name": "Z2298723851", "collision_energy": 45.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 350.16225, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [45.0], "molecular_weight": 351.1695249120001}} +{"SMILES": "CC(O)(CCc1ccccc1)C(=O)NCc1nc(-c2ccccn2)n[nH]1", "group_id": 6, "peaks": {"mz": [41.014107, 131.050217, 145.052032, 147.081604, 158.059891, 174.078751, 190.048691], "intensity": [37.972, 17.81, 11.117, 25.146, 52.073, 100.0, 6.043], "annotation": []}, "summary": {"name": "Z2298723851", "collision_energy": 60.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 350.16225, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [60.0], "molecular_weight": 351.1695249120001}} +{"SMILES": "CC(O)(CCc1ccccc1)C(=O)NCc1nc(-c2ccccn2)n[nH]1", "group_id": 6, "peaks": {"mz": [173.545273, 174.077927, 175.062576, 192.102509, 202.0728, 350.162354], 
"intensity": [2.434, 21.126, 1.839, 4.319, 4.205, 100.0], "annotation": []}, "summary": {"name": "Z2298723851", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 350.16225, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 351.1695249120001}} +{"SMILES": "COc1cccc(CNC(=O)Cc2cc(O)ccc2)c1", "group_id": 7, "peaks": {"mz": [93.034081, 107.049844, 107.058685, 107.062653, 107.066559, 121.029137, 121.065552, 148.040043, 224.04866, 270.054291, 270.059479, 270.11359], "intensity": [2.02, 71.438, 2.031, 1.632, 0.724, 1.714, 1.767, 15.099, 0.77, 1.61, 1.026, 100.0], "annotation": []}, "summary": {"name": "AKOS010660364", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 270.11357, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 271.120843404}} +{"SMILES": "COc1cccc(CNC(=O)Cc2cc(O)ccc2)c1", "group_id": 7, "peaks": {"mz": [93.034081, 107.049843, 107.058685, 107.062653, 107.066559, 121.029137, 121.065552, 148.040054, 270.059479, 270.113617], "intensity": [2.02, 71.438, 2.031, 1.632, 0.724, 1.714, 1.767, 15.099, 1.026, 100.0], "annotation": []}, "summary": {"name": "AKOS010660364", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 270.11357, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 271.120843404}} +{"SMILES": "CCC(CC)C(=O)N1CCC(NC(=O)Nc2cnc(Oc3cnccc3)cc2)CC1", "group_id": 10, "peaks": {"mz": [173.544418, 186.06752, 410.223088], "intensity": [1.859, 100.0, 7.876], "annotation": []}, "summary": {"name": "Z1534375649", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 410.21976, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 411.227039788}} +{"SMILES": "CCC(CC)C(=O)N1CCC(NC(=O)Nc2cnc(Oc3cnccc3)cc2)CC1", "group_id": 10, "peaks": {"mz": [173.544418, 186.06752, 410.223114], "intensity": [1.859, 100.0, 7.876], "annotation": []}, "summary": {"name": "Z1534375649", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 410.21976, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 411.227039788}} +{"SMILES": "CN(CC(=O)O)S(=O)(=O)c1ccc(C(F)(F)F)cc1", "group_id": 13, "peaks": {"mz": [52.248657, 63.961967, 64.969742, 64.973907, 64.975868, 145.026627, 161.021301, 161.036972, 173.543854, 186.053482, 208.987946, 209.008575, 209.023773, 209.033691, 209.048004, 209.060806, 238.014816, 296.020477, 296.052521, 296.098297], "intensity": [0.239, 1.158, 5.374, 0.35, 0.194, 3.422, 4.533, 0.11, 0.155, 0.118, 100.0, 1.518, 1.128, 0.884, 0.4, 0.16, 1.261, 12.286, 0.221, 0.117], "annotation": []}, "summary": {"name": "1097822-39-3", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 296.02099, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 297.02826346}} +{"SMILES": "CCNC(=O)CN(C)c1nc(C(F)(F)F)nc2ccccc12", "group_id": 14, "peaks": {"mz": [68.995354, 86.060638, 127.029678, 145.076416, 163.048187, 170.072235, 193.057678, 197.032455, 198.039871, 215.04303, 220.044907, 220.06842, 225.051254, 225.074142, 226.002411, 226.058792, 240.074661, 240.099594, 240.119064, 240.130386, 251.030014, 266.054199, 311.076447, 311.111725, 311.16806], "intensity": [0.591, 1.315, 1.605, 1.159, 0.811, 0.43, 1.086, 1.275, 1.047, 5.693, 0.438, 9.788, 47.993, 0.561, 1.075, 1.114, 100.0, 1.013, 1.123, 0.795, 
0.499, 1.872, 0.406, 1.478, 8.01], "annotation": []}, "summary": {"name": "Z751549272", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 311.11252, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 312.11979576000004}} +{"SMILES": "CCNC(=O)CN(C)c1nc(C(F)(F)F)nc2ccccc12", "group_id": 14, "peaks": {"mz": [86.060654, 173.539581, 174.955124, 215.043243, 220.068558, 225.051514, 227.009247, 240.074768, 240.099335, 240.119385, 240.130859, 266.054199, 311.111908, 311.16803], "intensity": [0.529, 1.187, 0.519, 0.793, 0.997, 5.677, 0.496, 100.0, 0.858, 0.995, 0.784, 1.635, 25.55, 6.129], "annotation": []}, "summary": {"name": "Z751549272", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 311.11252, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 312.11979576000004}} +{"SMILES": "O=C1CCC[C@H](C(=O)Nc2ccc(-c3ncc4CCCCn43)cc2)N1", "group_id": 15, "peaks": {"mz": [41.998081, 58.029343, 81.034203, 96.045128, 98.060638, 124.040108, 162.983047, 183.08017, 212.1194, 337.168488], "intensity": [21.641, 4.729, 6.692, 15.163, 5.783, 10.445, 3.623, 8.384, 100.0, 24.644], "annotation": []}, "summary": {"name": "Z1996268860", "collision_energy": 45.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 337.167, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [45.0], "molecular_weight": 338.17427594400004}} +{"SMILES": "CS(=O)(=O)N1CCc2c1ccc(C(=O)N1CCCC(C(=O)O)C1)c2", "group_id": 16, "peaks": {"mz": [173.540619, 228.126526, 272.116368, 273.12439, 287.140717, 307.112151, 351.062073, 351.101826, 351.139374], "intensity": [5.895, 2.001, 5.504, 3.781, 2.167, 23.862, 3.791, 100.0, 3.39], "annotation": []}, "summary": {"name": "AKOS022144871", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 351.10202, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 352.10929273999994}} +{"SMILES": "CS(=O)(=O)N1CCc2c1ccc(C(=O)N1CCCC(C(=O)O)C1)c2", "group_id": 16, "peaks": {"mz": [173.540619, 228.126526, 272.116241, 273.12439, 287.140717, 307.112061, 351.101898, 351.139374], "intensity": [5.895, 2.001, 5.504, 3.781, 2.167, 23.862, 100.0, 3.39], "annotation": []}, "summary": {"name": "AKOS022144871", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 351.10202, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 352.10929273999994}} +{"SMILES": "CC(NC(=O)CC1CCCCC1)C(=O)NC1CCN(c2ccccn2)C1", "group_id": 21, "peaks": {"mz": [94.980446, 94.994736, 233.140823, 261.147552, 357.230835], "intensity": [3.739, 1.6, 100.0, 2.458, 17.983], "annotation": []}, "summary": {"name": "AKOS033384282", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 357.2296, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 358.2368762}} +{"SMILES": "CC(NC(=O)CC1CCCCC1)C(=O)NC1CCN(c2ccccn2)C1", "group_id": 21, "peaks": {"mz": [94.980797, 94.995033, 233.140793, 261.146088, 357.229523], "intensity": [5.296, 2.652, 37.19, 2.537, 100.0], "annotation": []}, "summary": {"name": "AKOS033384282", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 357.2296, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 358.2368762}} +{"SMILES": "O=C(O)C1CCN(S(=O)(=O)c2cscc2)CC1", "group_id": 22, "peaks": {"mz": [63.962597, 
64.969757, 82.065666, 82.99572, 139.05806, 146.027486, 146.040817, 146.047852, 146.957425, 161.968216, 166.069046, 173.543859, 230.030773, 230.054276, 230.083572, 273.95047, 274.020579, 274.047028, 274.075195, 274.089569, 274.110321], "intensity": [0.277, 4.073, 4.436, 1.507, 1.037, 23.985, 0.558, 0.449, 16.705, 6.298, 0.993, 0.45, 41.411, 0.615, 0.396, 0.409, 100.0, 1.686, 0.992, 0.776, 0.286], "annotation": []}, "summary": {"name": "AKOS026649915", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 274.02132, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 275.02859989600006}} +{"SMILES": "O=C(O)C1CCN(S(=O)(=O)c2cscc2)CC1", "group_id": 22, "peaks": {"mz": [63.962597, 64.969757, 82.065666, 82.99572, 139.05806, 146.027466, 146.040817, 146.047852, 146.957397, 161.968216, 166.069046, 173.543091, 230.030746, 230.054276, 230.083572, 273.95047, 274.020538, 274.047028, 274.075195, 274.089569, 274.110321], "intensity": [0.277, 4.073, 4.436, 1.507, 1.037, 23.985, 0.558, 0.449, 16.705, 6.298, 0.993, 0.45, 41.411, 0.615, 0.396, 0.409, 100.0, 1.686, 0.992, 0.776, 0.286], "annotation": []}, "summary": {"name": "AKOS026649915", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 274.02132, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 275.02859989600006}} +{"SMILES": "Cc1c(C(=O)Nc2cnc(OC(F)F)cc2)cnn1C1CCOCC1", "group_id": 24, "peaks": {"mz": [165.102639, 173.542603, 351.127197], "intensity": [17.76, 4.511, 100.0], "annotation": []}, "summary": {"name": "Z2021219127", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 351.12742, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 352.13469687599996}} +{"SMILES": "Cc1c(C(=O)Nc2cnc(OC(F)F)cc2)cnn1C1CCOCC1", "group_id": 24, "peaks": {"mz": [165.102631, 173.542603, 351.127167], "intensity": [16.366, 4.511, 100.0], "annotation": []}, "summary": {"name": "Z2021219127", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 351.12742, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 352.13469687599996}} +{"SMILES": "CC(NCc1nc(=O)c2c(ccs2)[nH]1)c1ccccc1", "group_id": 26, "peaks": {"mz": [41.998062, 81.975204, 107.990791, 121.993874, 122.005241, 149.988388, 150.996521, 164.004242, 165.012115, 178.00766, 195.993851], "intensity": [83.749, 9.405, 9.507, 55.152, 8.622, 14.937, 8.978, 61.977, 100.0, 10.316, 12.963], "annotation": []}, "summary": {"name": "2-{[(1-phenylethyl)amino]methyl}-3H,4H-thieno[3,2-d]pyrimidin-4-one", "collision_energy": 60.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 284.08631, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [60.0], "molecular_weight": 285.0935831}} +{"SMILES": "CC(NCc1nc(=O)c2c(ccs2)[nH]1)c1ccccc1", "group_id": 26, "peaks": {"mz": [150.996811, 165.012192, 173.542007, 284.086365], "intensity": [1.284, 10.995, 2.893, 100.0], "annotation": []}, "summary": {"name": "2-{[(1-phenylethyl)amino]methyl}-3H,4H-thieno[3,2-d]pyrimidin-4-one", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 284.08631, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 285.0935831}} +{"SMILES": "CC(NCc1nc(=O)c2c(ccs2)[nH]1)c1ccccc1", "group_id": 26, "peaks": {"mz": [41.998253, 107.99118, 121.994514, 149.989319, 150.997238, 164.005188, 165.01297, 173.544098, 
178.008423, 181.979568, 195.995377, 284.086761], "intensity": [25.597, 3.631, 4.583, 7.378, 4.66, 25.048, 100.0, 4.604, 4.794, 5.818, 4.658, 5.326], "annotation": []}, "summary": {"name": "2-{[(1-phenylethyl)amino]methyl}-3H,4H-thieno[3,2-d]pyrimidin-4-one", "collision_energy": 45.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 284.08631, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [45.0], "molecular_weight": 285.0935831}} +{"SMILES": "O=C(NC1CCOCCC1(F)F)c1cc2c(cc(C(=O)O)s2)[nH]1", "group_id": 27, "peaks": {"mz": [122.007553, 165.996933, 173.541977, 206.972207, 216.04982, 229.044225, 231.058624, 241.044159, 259.054683, 273.033966, 279.061096, 285.033813, 293.040741, 299.067108, 303.04467, 323.050796, 342.959045, 343.056905, 343.133453, 343.155182], "intensity": [0.451, 0.537, 0.666, 1.131, 0.247, 6.295, 0.308, 0.876, 10.489, 2.879, 0.807, 0.525, 0.49, 1.094, 11.913, 35.267, 0.581, 100.0, 0.544, 0.651], "annotation": []}, "summary": {"name": "Z3389533917", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 343.05696, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 344.06423436800003}} +{"SMILES": "O=C(NC1CCOCCC1(F)F)c1cc2c(cc(C(=O)O)s2)[nH]1", "group_id": 27, "peaks": {"mz": [122.007553, 165.996933, 206.972946, 216.04982, 229.04425, 231.058624, 241.044159, 259.054718, 273.033966, 279.061096, 285.033813, 293.040741, 299.067108, 303.044708, 323.050812, 342.959045, 343.056915, 343.133453, 343.155182], "intensity": [0.451, 0.537, 0.714, 0.247, 6.295, 0.308, 0.876, 10.489, 2.879, 0.807, 0.525, 0.49, 1.094, 11.913, 35.267, 0.581, 100.0, 0.544, 0.651], "annotation": []}, "summary": {"name": "Z3389533917", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 343.05696, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 344.06423436800003}} +{"SMILES": "COc1ccccc1CNC(=O)C1CCCN(C(=O)c2cc3c(cc2)[nH]nc3)C1", "group_id": 30, "peaks": {"mz": [43.463173, 43.465443, 173.54007, 173.5448, 285.135048, 391.177165], "intensity": [1.567, 2.573, 6.218, 2.656, 25.498, 100.0], "annotation": []}, "summary": {"name": "AKOS034790703", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 391.17757, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 392.1848406280001}} +{"SMILES": "COc1ccccc1CNC(=O)C1CCCN(C(=O)c2cc3c(cc2)[nH]nc3)C1", "group_id": 30, "peaks": {"mz": [43.463173, 43.465443, 173.54007, 285.134979, 391.177032], "intensity": [1.567, 2.573, 6.218, 25.498, 100.0], "annotation": []}, "summary": {"name": "AKOS034790703", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 391.17757, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 392.1848406280001}} +{"SMILES": "CCc1ccccc1C(=O)N1CCC(OC)(C(=O)O)C1", "group_id": 31, "peaks": {"mz": [41.997948, 42.000206, 69.033035, 148.075685, 148.089783, 148.096954, 173.544708, 181.896611, 200.106812, 210.02216, 232.13289, 232.157242, 232.17511, 232.186066, 244.096268, 276.068115, 276.122463, 276.1521, 276.17749, 276.192871], "intensity": [3.02, 0.302, 0.295, 30.118, 0.375, 0.589, 0.311, 0.362, 3.81, 0.488, 81.28, 1.339, 1.016, 0.695, 1.533, 0.572, 100.0, 0.859, 0.896, 0.784], "annotation": []}, "summary": {"name": "AKOS034018460", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 276.12413, "precursor_mode": "[M-H]-", 
"retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 277.131408088}} +{"SMILES": "CCc1ccccc1C(=O)N1CCC(OC)(C(=O)O)C1", "group_id": 31, "peaks": {"mz": [41.997948, 42.000206, 69.033035, 148.075684, 148.089783, 148.096954, 173.544708, 181.896744, 200.106812, 210.022598, 232.132874, 232.157242, 232.17511, 232.186066, 244.096268, 276.122437, 276.1521, 276.17749, 276.192871], "intensity": [3.02, 0.302, 0.295, 30.118, 0.375, 0.589, 0.311, 0.306, 3.81, 0.395, 81.28, 1.339, 1.016, 0.695, 1.533, 100.0, 0.859, 0.896, 0.784], "annotation": []}, "summary": {"name": "AKOS034018460", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 276.12413, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 277.131408088}} +{"SMILES": "C[C@H]1CN(c2c3[nH]cnc3ncn2)CCN1S(C)(=O)=O", "group_id": 32, "peaks": {"mz": [93.99601, 94.003006, 160.061783, 173.541061, 174.077133, 200.092819, 217.119263, 230.984879, 295.09668, 295.157684, 295.174652], "intensity": [9.303, 0.302, 1.261, 0.792, 1.846, 1.029, 18.561, 0.622, 100.0, 0.983, 0.888], "annotation": []}, "summary": {"name": "Z1601657934", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 295.09827, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 296.10554475199996}} +{"SMILES": "CCN(Cc1ccc(OC)cc1)S(=O)(=O)c1cc(C(=O)O)co1", "group_id": 33, "peaks": {"mz": [57.015583, 57.019012, 63.962391, 64.969574, 64.973808, 64.975571, 65.002571, 84.974571, 228.068527, 228.092255, 228.109528, 228.121017, 228.136627, 228.152695, 228.165909, 230.116974, 294.080597, 338.068848, 338.108063], "intensity": [0.103, 0.217, 0.781, 6.151, 0.355, 0.225, 0.417, 0.632, 100.0, 1.669, 1.351, 0.942, 0.333, 0.155, 0.099, 0.457, 0.113, 8.036, 0.099], "annotation": []}, "summary": {"name": "Z2437274018", "collision_energy": 15.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 338.07038, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [15.0], "molecular_weight": 339.0776582639999}} +{"SMILES": "CCN(Cc1ccc(OC)cc1)S(=O)(=O)c1cc(C(=O)O)co1", "group_id": 33, "peaks": {"mz": [57.019093, 63.961956, 63.966003, 64.965233, 64.969612, 64.973877, 64.975594, 64.977966, 65.002602, 79.956673, 84.974724, 159.081192, 173.54454, 228.069412, 228.111069, 228.122696, 230.118301, 338.069946], "intensity": [0.216, 2.761, 0.132, 0.35, 23.714, 1.197, 0.81, 0.231, 0.419, 0.293, 0.407, 0.11, 0.145, 100.0, 0.515, 0.64, 0.269, 1.05], "annotation": []}, "summary": {"name": "Z2437274018", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 338.07038, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 339.0776582639999}} +{"SMILES": "O=C(NCc1cc(NC(=O)C2CCC2)ccc1)NC1c2ccccc2CC1O", "group_id": 35, "peaks": {"mz": [41.998081, 119.06076, 130.065506, 147.055649, 148.076019, 173.541901, 174.055435, 201.102982, 203.118317, 229.097565, 350.189514, 360.17218], "intensity": [1.344, 2.673, 18.3, 5.713, 22.304, 3.604, 1.484, 2.078, 5.373, 100.0, 1.404, 8.455], "annotation": []}, "summary": {"name": "Z506939324", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 378.18232, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 379.18959165999996}} +{"SMILES": "O=C(NCc1cc(NC(=O)C2CCC2)ccc1)NC1c2ccccc2CC1O", "group_id": 35, "peaks": {"mz": [130.065109, 148.075577, 173.544296, 174.054932, 203.11731, 229.096603, 350.186218, 
360.169739, 378.180756], "intensity": [5.921, 15.751, 2.715, 1.55, 2.651, 100.0, 2.327, 37.097, 6.604], "annotation": []}, "summary": {"name": "Z506939324", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 378.18232, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 379.18959165999996}} +{"SMILES": "O=C(O)CNC(=O)c1ccc(NC(=O)CC2C=CCC2)cc1", "group_id": 36, "peaks": {"mz": [92.050285, 107.049929, 125.0605, 149.071633, 164.927139, 173.541061, 173.547852, 200.107858, 200.126358, 230.118225, 255.113967, 257.129635, 301.12017], "intensity": [5.075, 10.029, 2.139, 9.392, 1.69, 3.804, 1.383, 100.0, 1.705, 2.487, 4.03, 29.503, 3.64], "annotation": []}, "summary": {"name": "AKOS026858108", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 301.11938, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 302.126657056}} +{"SMILES": "O=C(O)CNC(=O)c1ccc(NC(=O)CC2C=CCC2)cc1", "group_id": 36, "peaks": {"mz": [107.049583, 136.050858, 149.071433, 164.92717, 173.540863, 200.107512, 230.118362, 255.11441, 257.128937, 301.118757], "intensity": [5.265, 7.921, 10.251, 3.265, 8.801, 47.645, 3.113, 3.449, 100.0, 88.542], "annotation": []}, "summary": {"name": "AKOS026858108", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 301.11938, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 302.126657056}} +{"SMILES": "O=C(O)CNC(=O)c1ccc(NC(=O)CC2C=CCC2)cc1", "group_id": 36, "peaks": {"mz": [92.05043, 107.050064, 125.060539, 149.071838, 173.541061, 200.108261, 255.114334, 257.130432, 301.121643], "intensity": [6.235, 11.569, 2.847, 9.555, 6.901, 100.0, 4.856, 30.496, 5.3], "annotation": []}, "summary": {"name": "AKOS026858108", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 301.11938, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 302.126657056}} +{"SMILES": "O=C(O)CNC(=O)c1ccc(NC(=O)CC2C=CCC2)cc1", "group_id": 36, "peaks": {"mz": [107.049606, 149.071289, 173.540756, 200.107529, 230.118362, 255.11441, 257.128937, 301.118774], "intensity": [6.019, 8.13, 10.061, 43.323, 3.559, 3.942, 100.0, 80.761], "annotation": []}, "summary": {"name": "AKOS026858108", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 301.11938, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 302.126657056}} +{"SMILES": "O=C(NCC(F)(F)F)c1ccccc1NC(=O)[C@@H]1C[C@H]1C(=O)O", "group_id": 37, "peaks": {"mz": [67.018433, 111.008278, 124.001312, 131.014572, 145.040436, 157.039902, 160.076294, 172.076233, 173.544425, 177.046312, 177.065002, 177.074219, 177.082123, 183.055923, 186.055374, 197.052558, 197.072372, 204.065125, 207.055389, 217.058703, 217.080902, 217.096146, 217.107742, 223.050949, 225.066315, 227.061966, 230.045425, 245.072769, 247.068207, 251.045624, 257.054108, 265.078827, 269.056824, 271.052063, 285.085083, 289.063477, 291.058107, 291.087433, 291.134399, 309.068848, 311.064491, 311.098724, 328.983215, 329.074893, 329.109375, 329.145752, 329.16748, 329.193359], "intensity": [0.928, 5.801, 0.291, 0.231, 0.854, 0.678, 0.574, 1.09, 0.376, 43.042, 0.439, 0.641, 0.44, 0.234, 3.185, 10.838, 0.6, 0.204, 0.246, 32.542, 0.431, 0.364, 0.344, 0.424, 7.505, 4.502, 0.512, 1.602, 2.63, 4.376, 0.377, 3.533, 0.534, 0.77, 3.193, 0.22, 46.98, 0.836, 0.372, 0.712, 14.279, 0.217, 0.494, 100.0, 
1.863, 0.652, 0.848, 0.298], "annotation": []}, "summary": {"name": "Z2467003114", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 329.07547, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 330.082741556}} +{"SMILES": "O=C(NCC(F)(F)F)c1ccccc1NC(=O)[C@@H]1C[C@H]1C(=O)O", "group_id": 37, "peaks": {"mz": [67.018433, 111.008278, 124.001312, 131.014572, 145.040436, 157.039902, 160.076294, 172.076233, 173.544098, 177.04631, 177.065002, 177.074219, 177.082123, 183.055923, 186.055374, 197.052567, 197.072372, 204.065125, 207.055389, 217.058701, 217.080902, 217.096146, 217.107742, 223.050949, 225.066315, 227.061966, 230.045425, 245.072769, 247.068207, 251.045624, 257.054108, 265.078827, 269.056824, 271.052063, 285.085083, 289.063477, 291.058105, 291.087433, 291.134399, 309.068848, 311.064484, 311.098724, 328.983215, 329.07489, 329.109375, 329.145752, 329.16748, 329.193359], "intensity": [0.928, 5.801, 0.291, 0.231, 0.854, 0.678, 0.574, 1.09, 0.376, 43.042, 0.439, 0.641, 0.44, 0.234, 3.185, 10.838, 0.6, 0.204, 0.246, 32.542, 0.431, 0.364, 0.344, 0.424, 7.505, 4.502, 0.512, 1.602, 2.63, 4.376, 0.377, 3.533, 0.534, 0.77, 3.193, 0.22, 46.98, 0.836, 0.372, 0.712, 14.279, 0.217, 0.494, 100.0, 1.863, 0.652, 0.848, 0.298], "annotation": []}, "summary": {"name": "Z2467003114", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 329.07547, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 330.082741556}} +{"SMILES": "O=C(CCC1CCNCC1)Nc1ccc(-c2ncn[nH]2)cc1", "group_id": 41, "peaks": {"mz": [41.998143, 117.045343, 129.045232, 130.053061, 131.048607, 131.060937, 143.04833, 157.040032, 158.059128, 158.074115, 158.082642, 159.066956, 170.071764, 185.046152, 256.144882, 270.16084, 298.166759], "intensity": [1.138, 4.331, 3.858, 4.655, 4.102, 12.924, 1.335, 1.064, 100.0, 1.334, 2.052, 25.531, 5.301, 4.145, 3.567, 1.056, 7.915], "annotation": []}, "summary": {"name": "Z2437424895", "collision_energy": 60.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 298.16733, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [60.0], "molecular_weight": 299.174610292}} +{"SMILES": "O=C(CCC1CCNCC1)Nc1ccc(-c2ncn[nH]2)cc1", "group_id": 41, "peaks": {"mz": [41.998222, 117.045273, 129.045151, 130.053009, 131.048615, 131.060913, 143.048309, 157.040024, 158.059113, 158.073837, 158.082642, 159.06694, 170.071686, 185.046219, 256.144897, 270.16098, 298.166565], "intensity": [0.886, 4.334, 3.86, 4.658, 4.105, 11.871, 1.336, 1.015, 100.0, 1.334, 2.053, 25.377, 5.304, 3.383, 3.548, 1.014, 7.921], "annotation": []}, "summary": {"name": "Z2437424895", "collision_energy": 60.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 298.16733, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [60.0], "molecular_weight": 299.174610292}} +{"SMILES": "Cc1c(C(=O)O)c(NS(=O)(=O)c2cscc2)c(F)cc1", "group_id": 42, "peaks": {"mz": [206.044218, 270.00647, 313.996625], "intensity": [3.506, 100.0, 20.333], "annotation": []}, "summary": {"name": "Z2066021753", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 313.99625, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 315.00352802000003}} +{"SMILES": "Cc1c(C(=O)O)c(NS(=O)(=O)c2cscc2)c(F)cc1", "group_id": 42, "peaks": {"mz": [206.043762, 270.00647, 313.996552], "intensity": [3.506, 100.0, 20.333], "annotation": []}, "summary": {"name": 
"Z2066021753", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 313.99625, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 315.00352802000003}} +{"SMILES": "O=C(O)CCC1CCCCN1C(=O)NCc1cc(Cl)ccc1", "group_id": 45, "peaks": {"mz": [156.102203, 156.116852, 156.125259, 323.116089], "intensity": [100.0, 1.587, 1.854, 4.458], "annotation": []}, "summary": {"name": "3-(1-{[(3-chlorophenyl)methyl]carbamoyl}piperidin-2-yl)propanoic acid", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 323.11679, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 324.124070212}} +{"SMILES": "CCCCC(Nc1cc(C)nc2ncnn12)C(=O)O", "group_id": 47, "peaks": {"mz": [41.014206, 66.0093, 83.035843, 91.029602, 133.051331, 133.062851, 133.069199, 133.074753, 133.081558, 133.088394, 147.054398, 148.062271, 148.075775, 148.082718, 161.069885, 173.539444, 173.543091, 177.041168, 191.129654, 218.140472, 262.130066], "intensity": [0.249, 0.764, 0.562, 0.567, 100.0, 2.307, 1.748, 0.968, 0.376, 0.215, 0.382, 9.521, 0.223, 0.148, 0.297, 0.366, 0.212, 0.405, 0.214, 4.458, 0.649], "annotation": []}, "summary": {"name": "AKOS009544310", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 262.13095, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 263.138224784}} +{"SMILES": "CCCCC(Nc1cc(C)nc2ncnn12)C(=O)O", "group_id": 47, "peaks": {"mz": [83.035858, 133.051557, 133.062668, 133.069748, 133.074768, 133.081726, 148.062594, 148.075577, 161.070511, 173.539673, 177.041504, 191.130081, 218.141037, 262.131107], "intensity": [0.372, 100.0, 1.603, 1.679, 0.933, 0.389, 9.702, 0.166, 0.241, 0.397, 0.445, 0.181, 4.472, 18.854], "annotation": []}, "summary": {"name": "AKOS009544310", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 262.13095, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 263.138224784}} +{"SMILES": "CCCCC(Nc1cc(C)nc2ncnn12)C(=O)O", "group_id": 47, "peaks": {"mz": [41.014206, 66.0093, 83.035843, 91.029602, 133.051331, 133.062851, 133.069199, 133.074753, 133.081558, 133.088394, 147.054398, 148.062271, 148.075775, 148.082718, 161.069885, 173.543091, 177.041168, 191.129654, 218.140472, 262.130066], "intensity": [0.249, 0.764, 0.562, 0.567, 100.0, 2.307, 1.748, 0.968, 0.376, 0.215, 0.382, 9.521, 0.223, 0.148, 0.297, 0.212, 0.405, 0.214, 4.458, 0.649], "annotation": []}, "summary": {"name": "AKOS009544310", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 262.13095, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 263.138224784}} +{"SMILES": "CCCCC(Nc1cc(C)nc2ncnn12)C(=O)O", "group_id": 47, "peaks": {"mz": [83.035858, 133.051559, 133.062668, 133.069748, 133.074768, 133.081726, 148.062607, 148.075577, 161.070511, 177.041504, 191.130081, 218.141037, 262.131134], "intensity": [0.372, 100.0, 1.603, 1.679, 0.933, 0.389, 9.702, 0.166, 0.241, 0.445, 0.181, 4.472, 18.854], "annotation": []}, "summary": {"name": "AKOS009544310", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 262.13095, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 263.138224784}} +{"SMILES": "Cc1nsc2c1cc(C(=O)N1CCC(C(=O)O)C1)cn2", "group_id": 48, "peaks": {"mz": [41.998093, 42.000252, 81.975365, 96.045227, 123.001801, 
149.002579, 149.017242, 149.031387, 149.038773, 149.044952, 149.053101, 173.538589, 192.023087, 193.007126, 205.030869, 237.021149, 246.069977, 246.095764, 246.128784, 246.146042, 260.049347, 290.060089], "intensity": [2.162, 0.197, 0.311, 0.141, 1.183, 0.369, 100.0, 1.588, 1.831, 1.01, 0.386, 0.533, 5.39, 0.606, 0.617, 0.279, 57.765, 0.839, 0.475, 0.155, 1.047, 8.328], "annotation": []}, "summary": {"name": "Z1455174573", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 290.06049, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 291.06776227600005}} +{"SMILES": "Cc1nsc2c1cc(C(=O)N1CCC(C(=O)O)C1)cn2", "group_id": 48, "peaks": {"mz": [41.998348, 149.018326, 192.024323, 245.995895, 246.012192, 246.046387, 246.071579, 260.051117, 289.962921, 289.985718, 290.002686, 290.028625, 290.061829], "intensity": [0.641, 26.887, 3.827, 0.251, 0.826, 1.321, 100.0, 2.176, 0.289, 0.746, 1.002, 0.849, 98.369], "annotation": []}, "summary": {"name": "Z1455174573", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 290.06049, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 291.06776227600005}} +{"SMILES": "O=C(O)CCC1CCN(C(=O)NC2CCCC2)C1", "group_id": 51, "peaks": {"mz": [142.073013, 142.086914, 142.100525, 142.10704, 142.112595, 142.120239, 142.128128, 173.54155, 253.085281, 253.155975], "intensity": [0.307, 100.0, 0.939, 1.514, 0.933, 0.302, 0.22, 0.289, 0.44, 16.415], "annotation": []}, "summary": {"name": "AKOS020295868", "collision_energy": 15.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 253.15577, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [15.0], "molecular_weight": 254.163042564}} +{"SMILES": "O=C(O)CCC1CCN(C(=O)NC2CCCC2)C1", "group_id": 51, "peaks": {"mz": [142.086609, 142.099792, 142.106705, 142.112518, 142.119888, 142.127777, 253.08374, 253.155182], "intensity": [100.0, 1.228, 2.006, 0.916, 0.399, 0.258, 0.226, 3.154], "annotation": []}, "summary": {"name": "AKOS020295868", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 253.15577, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 254.163042564}} +{"SMILES": "CC(=O)Nc1nc(C(=O)Nc2cc3c(cc2)CCN3C)cs1", "group_id": 53, "peaks": {"mz": [57.975141, 107.025017, 141.012366, 173.545044, 273.081421, 315.092259], "intensity": [6.452, 0.992, 20.236, 1.268, 1.682, 100.0], "annotation": []}, "summary": {"name": "Z1595247480", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 315.09212, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 316.09939675199996}} +{"SMILES": "CC(=O)Nc1nc(C(=O)Nc2cc3c(cc2)CCN3C)cs1", "group_id": 53, "peaks": {"mz": [57.975124, 107.025017, 141.01239, 273.081421, 315.092285], "intensity": [6.452, 0.992, 20.236, 1.682, 100.0], "annotation": []}, "summary": {"name": "Z1595247480", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 315.09212, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 316.09939675199996}} +{"SMILES": "O=C(NCC1(Cc2ccc(F)cc2)CC1)N1CCn2ncnc2C1", "group_id": 54, "peaks": {"mz": [68.02494, 81.032784, 82.040604, 121.051552, 123.067184, 123.077782, 123.083031, 123.088013, 173.542755, 209.037979], "intensity": [3.444, 4.907, 2.685, 3.908, 100.0, 2.103, 1.914, 1.1, 1.227, 0.751], "annotation": []}, "summary": {"name": 
"Z1719625346", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 328.15791, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 329.16518848000004}} +{"SMILES": "O=C(NCC1(Cc2ccc(F)cc2)CC1)N1CCn2ncnc2C1", "group_id": 54, "peaks": {"mz": [68.025024, 81.032806, 121.051453, 123.067055, 123.077805, 123.083046, 328.15741], "intensity": [0.703, 0.823, 2.875, 100.0, 2.517, 2.36, 6.67], "annotation": []}, "summary": {"name": "Z1719625346", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 328.15791, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 329.16518848000004}} +{"SMILES": "O=C1CCCN1c1ccc(S(=O)(=O)Nc2cc3c(nc2)OCC3)cc1", "group_id": 56, "peaks": {"mz": [41.998269, 61.970346, 63.962212, 91.042786, 92.050463, 107.037646, 108.04539, 133.040616, 134.048376, 135.032394, 135.056258, 160.076635, 173.538315, 173.545624, 176.071543, 182.085068, 197.003281, 224.03863, 239.119232, 251.118602, 267.113464, 292.109906, 293.116961, 294.124633, 315.08136, 358.087342], "intensity": [25.623, 10.464, 19.309, 2.189, 4.982, 9.048, 17.835, 3.197, 27.428, 21.482, 80.475, 100.0, 9.384, 2.695, 14.167, 3.576, 2.362, 32.49, 2.19, 5.611, 2.13, 2.907, 2.87, 4.099, 2.383, 5.692], "annotation": []}, "summary": {"name": "Z3633261485", "collision_energy": 60.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 358.0867, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [60.0], "molecular_weight": 359.093977024}} +{"SMILES": "O=C1CCCN1c1ccc(S(=O)(=O)Nc2cc3c(nc2)OCC3)cc1", "group_id": 56, "peaks": {"mz": [135.055908, 160.076172, 173.542358, 224.03746, 358.085999], "intensity": [34.925, 44.129, 79.451, 36.752, 100.0], "annotation": []}, "summary": {"name": "Z3633261485", "collision_energy": 45.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 358.0867, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [45.0], "molecular_weight": 359.093977024}} +{"SMILES": "O=C1CCCN1c1ccc(S(=O)(=O)Nc2cc3c(nc2)OCC3)cc1", "group_id": 56, "peaks": {"mz": [41.99844, 61.970581, 63.96246, 91.042786, 92.050804, 107.037979, 108.045784, 133.041046, 134.048828, 135.032822, 135.056656, 160.077118, 173.545624, 176.072052, 182.085068, 197.003281, 224.039078, 239.119232, 251.11911, 267.113464, 292.110229, 293.116577, 294.125519, 315.08136, 358.088013], "intensity": [26.522, 10.831, 19.987, 2.266, 5.157, 9.366, 18.461, 3.309, 28.159, 21.394, 83.3, 100.0, 2.79, 14.664, 3.702, 2.445, 33.63, 2.267, 5.014, 2.205, 3.01, 2.94, 4.243, 2.466, 5.892], "annotation": []}, "summary": {"name": "Z3633261485", "collision_energy": 60.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 358.0867, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [60.0], "molecular_weight": 359.093977024}} +{"SMILES": "CN(Cc1cc(O)ccc1)C(=O)CCc1c[nH]c2ccccc12", "group_id": 57, "peaks": {"mz": [76.788887, 102.956288, 160.076218, 173.541465, 178.086765, 178.103653, 178.115067, 196.076694, 211.079146, 260.931256, 263.150787, 278.942759, 307.144621, 307.177917, 307.227325], "intensity": [0.233, 0.479, 1.205, 0.64, 22.342, 0.341, 0.363, 0.411, 1.562, 0.357, 0.268, 0.473, 100.0, 1.349, 0.829], "annotation": []}, "summary": {"name": "Z608788638", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 307.1452, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 308.15247788}} +{"SMILES": 
"CN(Cc1cc(O)ccc1)C(=O)CCc1c[nH]c2ccccc12", "group_id": 57, "peaks": {"mz": [63.962082, 72.044975, 72.049866, 72.051956, 92.026329, 93.034088, 95.049759, 105.034119, 106.041946, 107.049751, 107.058617, 111.024628, 116.050125, 119.049736, 120.044991, 121.028984, 121.065468, 123.044952, 133.028809, 134.060501, 136.076202, 136.088638, 136.094513, 144.045044, 145.052719, 147.044342, 147.117447, 158.060822, 160.076187, 160.090973, 162.055466, 163.06337, 178.086716, 178.103699, 178.114883, 178.12265, 211.079285, 214.92569, 307.144928], "intensity": [0.296, 13.332, 0.676, 0.419, 0.357, 0.78, 1.734, 0.481, 2.444, 12.845, 0.389, 0.571, 0.654, 3.225, 1.252, 1.22, 1.071, 0.339, 0.509, 1.651, 16.708, 0.298, 0.347, 1.359, 4.916, 0.364, 0.441, 1.324, 24.294, 0.482, 1.025, 0.766, 100.0, 1.734, 1.547, 0.984, 0.657, 1.136, 0.333], "annotation": []}, "summary": {"name": "Z608788638", "collision_energy": 45.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 307.1452, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [45.0], "molecular_weight": 308.15247788}} +{"SMILES": "CN(Cc1cc(O)ccc1)C(=O)CCc1c[nH]c2ccccc12", "group_id": 57, "peaks": {"mz": [76.788887, 102.956528, 160.076218, 173.541458, 178.086761, 178.103653, 178.115067, 196.076508, 211.0793, 260.930878, 263.150787, 278.943298, 307.144623, 307.177917, 307.227325], "intensity": [0.233, 0.252, 1.205, 0.64, 22.342, 0.341, 0.363, 0.295, 1.562, 0.357, 0.268, 0.473, 100.0, 1.349, 0.829], "annotation": []}, "summary": {"name": "Z608788638", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 307.1452, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 308.15247788}} +{"SMILES": "CS(=O)(=O)N1CCCC(NC(=O)c2cc3ccccc3[nH]2)C1", "group_id": 58, "peaks": {"mz": [41.998062, 42.000294, 42.001259, 63.961933, 63.966064, 76.969711, 78.972801, 78.985344, 78.991096, 93.960014, 93.996277, 94.00338, 94.006645, 94.010765, 115.042198, 116.039597, 116.050018, 116.06015, 116.06469, 116.069275, 116.074783, 125.071342, 141.045242, 158.060593, 159.055786, 173.541885, 177.06987, 185.071564, 207.092041, 225.101868, 242.129211], "intensity": [10.087, 0.972, 0.363, 5.845, 0.346, 0.559, 7.133, 7.241, 0.333, 0.435, 41.457, 1.596, 1.246, 0.419, 0.324, 0.644, 100.0, 2.261, 2.424, 1.015, 0.408, 0.513, 1.975, 0.359, 2.508, 0.626, 0.698, 0.341, 0.346, 0.312, 4.059], "annotation": []}, "summary": {"name": "Z828576222", "collision_energy": 60.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 320.10744, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [60.0], "molecular_weight": 321.11471246800005}} +{"SMILES": "CS(=O)(=O)N1CCCC(NC(=O)c2cc3ccccc3[nH]2)C1", "group_id": 58, "peaks": {"mz": [41.99807, 42.00024, 63.961891, 78.972748, 78.985367, 93.996277, 94.003372, 94.006653, 94.010612, 94.980331, 116.050095, 141.045181, 159.041702, 159.05571, 159.072189, 160.042709, 173.539093, 177.069534, 203.048782, 207.091965, 213.102707, 225.102585, 225.125748, 240.113632, 242.129135, 242.153946, 242.173553, 242.186829, 242.204132, 277.100647, 320.106964], "intensity": [2.268, 0.265, 0.806, 0.509, 5.687, 47.29, 1.858, 1.409, 0.5, 1.346, 11.791, 0.359, 0.294, 9.938, 0.268, 0.35, 0.923, 0.402, 0.868, 0.417, 1.861, 22.255, 0.355, 0.708, 100.0, 1.64, 1.063, 0.835, 0.36, 0.642, 43.486], "annotation": []}, "summary": {"name": "Z828576222", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 320.10744, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": 
[30.0], "molecular_weight": 321.11471246800005}} +{"SMILES": "CS(=O)(=O)N1CCCC(NC(=O)c2cc3ccccc3[nH]2)C1", "group_id": 58, "peaks": {"mz": [78.985466, 80.029846, 93.996338, 94.003342, 94.006706, 94.980408, 116.05027, 159.055832, 203.048813, 213.102554, 225.102692, 225.129257, 240.113815, 242.129227, 242.153885, 320.018219, 320.106995, 320.177246, 320.194977, 320.21994], "intensity": [0.655, 0.232, 7.344, 0.27, 0.224, 0.674, 0.958, 1.251, 0.249, 0.27, 5.42, 0.193, 0.252, 26.195, 0.435, 0.541, 100.0, 0.824, 0.872, 0.273], "annotation": []}, "summary": {"name": "Z828576222", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 320.10744, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 321.11471246800005}} +{"SMILES": "CCCOc1c(Cl)cc(NC(=O)c2cncc(-c3cn(C)nc3)c2)cc1", "group_id": 59, "peaks": {"mz": [81.517517, 95.992447, 167.985214, 243.132507, 323.10733, 325.965515, 326.057098, 326.091827, 326.127136, 326.147919, 326.174561, 369.112213], "intensity": [0.236, 0.183, 10.935, 0.619, 1.663, 0.531, 100.0, 1.621, 0.749, 0.709, 0.304, 2.988], "annotation": []}, "summary": {"name": "N-(3-chloro-4-propoxyphenyl)-5-(1-methyl-1H-pyrazol-4-yl)pyridine-3-carboxamide", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 369.11238, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 370.1196535280001}} +{"SMILES": "CCCOc1c(Cl)cc(NC(=O)c2cncc(-c3cn(C)nc3)c2)cc1", "group_id": 59, "peaks": {"mz": [41.011387, 41.013626, 167.986145, 321.11084, 323.107605, 326.057037, 326.092194, 326.126648, 326.148071, 369.112, 369.220154], "intensity": [0.772, 1.976, 0.9, 0.495, 3.852, 100.0, 1.529, 0.762, 0.756, 72.513, 0.542], "annotation": []}, "summary": {"name": "N-(3-chloro-4-propoxyphenyl)-5-(1-methyl-1H-pyrazol-4-yl)pyridine-3-carboxamide", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 369.11238, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 370.1196535280001}} +{"SMILES": "Cc1c(C(=O)NCC(=O)NC2CCCCC2)cccn1", "group_id": 61, "peaks": {"mz": [92.049957, 149.071182, 153.102585, 174.091507, 175.050644, 181.097397, 274.154968], "intensity": [47.366, 32.427, 48.147, 28.604, 32.777, 78.365, 100.0], "annotation": []}, "summary": {"name": "Z1155100127", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 274.1561, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 275.163376912}} +{"SMILES": "Cc1c(C(=O)NCC(=O)NC2CCCCC2)cccn1", "group_id": 61, "peaks": {"mz": [41.998219, 92.050194, 131.061447, 136.039871, 147.05603, 149.071579, 153.102892, 173.538132, 173.543854, 175.050827, 181.097795, 181.117508, 274.085114, 274.155688, 274.18219, 274.209747, 274.22583], "intensity": [0.218, 2.471, 0.219, 0.415, 0.598, 5.844, 5.983, 0.602, 0.342, 3.319, 8.388, 0.194, 0.506, 100.0, 1.583, 0.678, 0.752], "annotation": []}, "summary": {"name": "Z1155100127", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 274.1561, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 275.163376912}} +{"SMILES": "Cc1c(C(=O)NCC(=O)NC2CCCCC2)cccn1", "group_id": 61, "peaks": {"mz": [41.998219, 92.050194, 131.061447, 136.039871, 147.05603, 149.071564, 153.102844, 173.538132, 175.050827, 181.097778, 181.117508, 274.085114, 274.15567, 274.18219, 274.209747, 274.22583], "intensity": [0.218, 2.471, 0.219, 
0.415, 0.598, 5.844, 5.983, 0.602, 3.319, 8.388, 0.194, 0.506, 100.0, 1.583, 0.678, 0.752], "annotation": []}, "summary": {"name": "Z1155100127", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 274.1561, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 275.163376912}} +{"SMILES": "NC(=O)NC(Cc1ccccc1)C(=O)NCC1CN2CCCC2CO1", "group_id": 63, "peaks": {"mz": [41.99807, 173.545166, 280.98233, 285.160675, 302.186768, 345.156525], "intensity": [29.929, 12.737, 9.159, 19.049, 100.0, 15.467], "annotation": []}, "summary": {"name": "AKOS032976537", "collision_energy": 15.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 345.19322, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [15.0], "molecular_weight": 346.20049069199996}} +{"SMILES": "NC(=O)NC(Cc1ccccc1)C(=O)NCC1CN2CCCC2CO1", "group_id": 63, "peaks": {"mz": [41.998322, 210.125397, 280.985107, 285.162079, 302.18866, 345.157562], "intensity": [30.412, 13.378, 10.879, 22.309, 100.0, 14.19], "annotation": []}, "summary": {"name": "AKOS032976537", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 345.19322, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 346.20049069199996}} +{"SMILES": "CC1(C(=O)O)CCCN(C(=O)CCC(F)(F)F)C1", "group_id": 64, "peaks": {"mz": [99.011973, 100.019805, 142.086644, 142.100118, 142.106638, 162.091751, 173.544662, 182.09781, 202.104081, 266.100226], "intensity": [2.972, 1.245, 100.0, 1.557, 1.965, 3.92, 0.754, 5.956, 1.463, 5.791], "annotation": []}, "summary": {"name": "AKOS015050168", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 266.10095, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 267.108228032}} +{"SMILES": "CC1(C(=O)O)CCCN(C(=O)CCC(F)(F)F)C1", "group_id": 64, "peaks": {"mz": [142.086931, 142.100407, 142.106888, 182.098404, 202.104614, 266.10102], "intensity": [100.0, 0.892, 1.659, 2.875, 1.755, 70.336], "annotation": []}, "summary": {"name": "AKOS015050168", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 266.10095, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 267.108228032}} +{"SMILES": "CC1(C(=O)O)CCCN(C(=O)CCC(F)(F)F)C1", "group_id": 64, "peaks": {"mz": [99.012001, 100.019798, 142.086655, 142.100113, 142.106644, 162.091751, 173.544662, 182.097794, 202.104187, 266.100281], "intensity": [2.972, 1.245, 100.0, 1.557, 1.965, 3.92, 0.754, 5.956, 1.463, 5.791], "annotation": []}, "summary": {"name": "AKOS015050168", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 266.10095, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 267.108228032}} +{"SMILES": "CC1(C(=O)O)CCCN(C(=O)CCC(F)(F)F)C1", "group_id": 64, "peaks": {"mz": [142.086884, 142.100113, 142.106888, 182.098221, 202.104614, 266.100922], "intensity": [100.0, 0.892, 1.659, 2.875, 1.755, 70.336], "annotation": []}, "summary": {"name": "AKOS015050168", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 266.10095, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 267.108228032}} +{"SMILES": "COc1cc2c(cc1)nc(NC(=O)CC1CCNCC1)s2", "group_id": 67, "peaks": {"mz": [173.54068, 179.028076, 289.089172, 304.112549], "intensity": [11.742, 9.286, 19.236, 100.0], "annotation": []}, "summary": 
{"name": "Z733322042", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 304.11252, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 305.119797848}} +{"SMILES": "Cc1c(Cl)cc(C(=O)NCC(C)(O)C(=O)O)cc1", "group_id": 71, "peaks": {"mz": [41.998109, 42.000336, 44.997762, 45.000233, 56.013706, 57.034111, 59.013439, 70.029381, 87.008286, 98.024317, 98.031883, 98.035339, 100.039932, 125.015991, 168.021706, 168.03714, 168.047623, 168.054779, 169.00574, 173.544586, 182.03714, 183.02153, 195.021561, 197.037033, 206.037598, 208.052982, 210.032254, 223.995987, 224.02298, 224.047841, 224.069138, 224.087631, 224.099533, 224.114349, 225.031921, 234.032059, 252.042993, 270.053511, 270.121155], "intensity": [9.266, 0.725, 5.144, 0.373, 0.241, 1.159, 0.565, 3.455, 1.22, 16.835, 0.541, 0.405, 0.475, 1.517, 29.084, 0.516, 0.386, 0.323, 4.795, 0.278, 0.725, 0.259, 1.268, 0.484, 0.25, 3.651, 1.618, 0.409, 0.199, 100.0, 1.729, 0.858, 0.867, 0.315, 1.379, 0.25, 8.831, 36.885, 0.283], "annotation": []}, "summary": {"name": "AKOS015515696", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 270.05386, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 271.06113560800003}} +{"SMILES": "Cc1c(Cl)cc(C(=O)NCC(C)(O)C(=O)O)cc1", "group_id": 71, "peaks": {"mz": [41.998112, 44.997747, 45.000225, 67.515724, 70.029449, 87.008461, 98.02433, 100.040024, 168.021675, 169.00563, 173.545349, 208.052918, 210.032867, 224.047744, 224.070679, 224.087326, 225.03183, 252.042867, 269.964264, 270.053266, 270.08139, 270.106689, 270.121307, 270.139984, 270.160126], "intensity": [0.42, 1.58, 0.152, 0.204, 0.146, 0.204, 0.752, 0.362, 6.157, 2.351, 0.132, 1.011, 0.131, 29.572, 0.454, 0.367, 0.13, 3.759, 0.142, 100.0, 1.504, 0.989, 0.885, 0.263, 0.145], "annotation": []}, "summary": {"name": "AKOS015515696", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 270.05386, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 271.06113560800003}} +{"SMILES": "Cc1c(Cl)cc(C(=O)NCC(C)(O)C(=O)O)cc1", "group_id": 71, "peaks": {"mz": [41.998112, 42.000336, 42.001381, 44.997765, 45.000233, 56.013706, 57.034111, 59.013439, 70.029388, 87.008286, 98.024323, 98.031883, 98.035339, 100.039932, 125.015991, 168.021713, 168.03714, 168.047623, 168.054779, 169.005753, 182.03714, 183.02153, 195.021561, 197.037033, 206.037598, 208.052963, 210.032196, 223.995987, 224.047852, 224.069138, 224.087631, 224.099533, 224.114349, 225.031921, 234.032059, 252.043015, 270.053528, 270.121155], "intensity": [9.266, 0.725, 0.278, 5.144, 0.373, 0.241, 1.159, 0.565, 3.455, 1.22, 16.835, 0.541, 0.405, 0.475, 1.517, 29.084, 0.516, 0.386, 0.323, 4.795, 0.725, 0.259, 1.268, 0.484, 0.25, 3.651, 1.618, 0.409, 100.0, 1.729, 0.858, 0.867, 0.315, 1.379, 0.25, 8.831, 36.885, 0.283], "annotation": []}, "summary": {"name": "AKOS015515696", "collision_energy": 30.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 270.05386, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [30.0], "molecular_weight": 271.06113560800003}} +{"SMILES": "Cc1c(Cl)cc(C(=O)NCC(C)(O)C(=O)O)cc1", "group_id": 71, "peaks": {"mz": [41.998112, 44.997738, 45.000225, 67.515724, 70.029449, 87.008461, 98.02433, 100.040024, 168.021667, 169.005646, 173.545319, 208.052887, 210.032867, 224.047745, 224.070679, 224.087326, 225.03183, 252.042847, 269.964264, 270.053253, 270.08139, 270.106689, 
270.121307, 270.139984, 270.160126], "intensity": [0.42, 1.58, 0.152, 0.204, 0.146, 0.204, 0.752, 0.362, 6.157, 2.351, 0.132, 1.011, 0.131, 29.572, 0.454, 0.367, 0.13, 3.759, 0.142, 100.0, 1.504, 0.989, 0.885, 0.263, 0.145], "annotation": []}, "summary": {"name": "AKOS015515696", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 270.05386, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 271.06113560800003}} +{"SMILES": "Cc1nccc(C(=O)N2CCC(CCC(=O)O)CC2)c1", "group_id": 76, "peaks": {"mz": [92.050011, 95.086029, 135.055771, 138.091599, 139.075745, 173.540588, 231.149414, 275.068481, 275.139374, 275.208496], "intensity": [5.183, 0.833, 2.454, 0.573, 0.786, 0.871, 10.733, 0.513, 100.0, 0.722], "annotation": []}, "summary": {"name": "Z2910698425", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 275.14012, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 276.1473925}} +{"SMILES": "COc1ccc(-c2c(C)sc(C(=O)N3CCN(C(=O)c4cc(O)ccc4)CC3)c2)cc1", "group_id": 78, "peaks": {"mz": [41.9981, 93.034119, 118.02948, 132.044937, 135.044693, 136.04007, 146.024261, 147.03212, 162.055527, 173.047653, 173.538391, 187.022171, 188.029556, 188.071075, 189.037277, 203.052994, 203.081619, 205.097733, 214.008682, 231.07692, 247.042694, 272.074707, 289.064636, 299.086212, 419.106598, 420.114349, 435.13916], "intensity": [8.244, 2.299, 1.263, 2.801, 1.997, 1.228, 1.688, 1.857, 40.057, 2.822, 5.882, 3.628, 36.05, 19.273, 2.965, 100.0, 11.182, 4.388, 1.406, 20.251, 2.831, 10.01, 2.33, 4.045, 2.798, 3.065, 10.398], "annotation": []}, "summary": {"name": "Z1007507134", "collision_energy": 45.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 435.1384, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [45.0], "molecular_weight": 436.14567824799997}} +{"SMILES": "COc1cc2c(cc1)[nH]cc2CC(=O)N1CCC(CS(C)(=O)=O)C1", "group_id": 79, "peaks": {"mz": [63.961926, 63.966057, 78.985371, 78.991112, 78.993507, 93.00103, 131.037094, 144.044856, 144.058899, 144.065353, 144.071259, 145.052765, 158.06015, 171.03241, 172.039825, 173.543686, 184.076828, 201.066132, 211.087555, 213.102722, 216.921585, 239.08165, 241.09758, 253.098328, 255.113098, 333.090912, 334.099915], "intensity": [15.192, 0.866, 25.518, 1.066, 0.765, 1.415, 1.948, 100.0, 1.505, 1.92, 1.017, 1.199, 0.691, 0.428, 0.814, 0.722, 0.466, 0.481, 0.427, 0.422, 0.45, 0.573, 1.322, 0.539, 2.06, 0.624, 0.688], "annotation": []}, "summary": {"name": "AKOS033801958", "collision_energy": 60.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 349.12275, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [60.0], "molecular_weight": 350.13002818399997}} +{"SMILES": "COc1cc2c(cc1)[nH]cc2CC(=O)N1CCC(CS(C)(=O)=O)C1", "group_id": 79, "peaks": {"mz": [78.98526, 334.099101, 334.145355, 349.122397, 349.221771], "intensity": [2.445, 11.801, 0.366, 100.0, 0.568], "annotation": []}, "summary": {"name": "AKOS033801958", "collision_energy": 20.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 349.12275, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [20.0], "molecular_weight": 350.13002818399997}} +{"SMILES": "COc1cc2c(cc1)[nH]cc2CC(=O)N1CCC(CS(C)(=O)=O)C1", "group_id": 79, "peaks": {"mz": [63.962403, 78.985967, 78.990555, 93.001656, 131.038071, 133.032669, 144.018951, 144.045676, 145.053604, 171.032898, 173.538559, 173.545456, 241.098511, 254.106216, 255.114279, 291.080994, 
333.091827, 334.099805, 349.122681], "intensity": [12.807, 44.861, 0.799, 2.309, 0.927, 0.702, 0.605, 100.0, 1.581, 0.754, 2.173, 0.782, 3.057, 2.556, 7.63, 1.235, 4.015, 29.302, 1.073], "annotation": []}, "summary": {"name": "AKOS033801958", "collision_energy": 45.0, "instrument": "HCD", "ionization": "ESI", "precursor_mz": 349.12275, "precursor_mode": "[M-H]-", "retention_time": NaN, "ce_steps": [45.0], "molecular_weight": 350.13002818399997}}
diff --git a/tests/test_mol_core.py b/tests/test_mol_core.py
new file mode 100644
index 0000000..d697ae9
--- /dev/null
+++ b/tests/test_mol_core.py
@@ -0,0 +1,240 @@
+import json
+from pathlib import Path
+
+import pytest
+import torch
+from rdkit import RDLogger
+
+from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder
+from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder
+from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder
+from fiora.MOL.Metabolite import Metabolite
+from fiora.MOL.MetaboliteIndex import MetaboliteIndex
+from fiora.MOL.collision_energy import NCE_to_eV, align_CE, nce_instruments
+
+
+SAMPLE_SIZE = 100
+SPECTRA_FIXTURE_PATH = (
+    Path(__file__).resolve().parent / "data" / "mol_core_100_spectra.jsonl"
+)
+
+RDLogger.DisableLog("rdApp.*")
+
+
+@pytest.fixture(scope="module")
+def sample_spectra():
+    if not SPECTRA_FIXTURE_PATH.exists():
+        pytest.skip(f"Missing test fixture: {SPECTRA_FIXTURE_PATH}")
+
+    records = []
+    with SPECTRA_FIXTURE_PATH.open("r", encoding="utf-8") as handle:
+        for line in handle:
+            line = line.strip()
+            if not line:
+                continue
+            records.append(json.loads(line))
+
+    if len(records) != SAMPLE_SIZE:
+        pytest.skip(
+            f"Expected {SAMPLE_SIZE} records in fixture, found {len(records)}"
+        )
+    return records
+
+
+@pytest.fixture(scope="module")
+def prepared_metabolites(sample_spectra):
+    node_encoder = AtomFeatureEncoder(
+        feature_list=["symbol", "num_hydrogen", "ring_type"]
+    )
+    bond_encoder = BondFeatureEncoder(feature_list=["bond_type", "ring_type"])
+    setup_encoder = CovariateFeatureEncoder(
+        feature_list=[
+            "collision_energy",
+            "molecular_weight",
+            "precursor_mode",
+            "instrument",
+            "element_composition",
+        ]
+    )
+    rt_encoder = CovariateFeatureEncoder(
+        feature_list=[
+            "molecular_weight",
+            "precursor_mode",
+            "instrument",
+            "element_composition",
+        ]
+    )
+
+    prepared = []
+    for row in sample_spectra:
+        metabolite = Metabolite(row["SMILES"])
+        metabolite.create_molecular_structure_graph()
+        metabolite.compute_graph_attributes(
+            node_encoder=node_encoder, bond_encoder=bond_encoder
+        )
+        metabolite.add_metadata(dict(row["summary"]), setup_encoder, rt_encoder)
+        prepared.append({"metabolite": metabolite, "peaks": row["peaks"]})
+    return prepared
+
+
+@pytest.fixture(scope="module")
+def indexed_metabolites(prepared_metabolites):
+    metabolites = [entry["metabolite"] for entry in prepared_metabolites]
+    index = MetaboliteIndex()
+    index.index_metabolites(metabolites)
+    index.create_fragmentation_trees(depth=1)
+    mismatches = index.add_fragmentation_trees_to_metabolite_list(metabolites)
+    return {"index": index, "metabolites": metabolites, "mismatches": mismatches}
+
+
+def test_load_100_spectra_fixture(sample_spectra):
+    assert len(sample_spectra) == SAMPLE_SIZE
+    first = sample_spectra[0]
+    assert {"SMILES", "peaks", "summary", "group_id"}.issubset(first.keys())
+    assert len(first["peaks"]["mz"]) == len(first["peaks"]["intensity"])
+    assert len(first["peaks"]["mz"]) > 0
+    assert "collision_energy" in first["summary"]
+    assert "precursor_mode" in first["summary"]
+
+
+def test_metabolite_graph_building(prepared_metabolites):
+    assert len(prepared_metabolites) == SAMPLE_SIZE
+
+    for entry in prepared_metabolites:
+        metabolite = entry["metabolite"]
+        assert metabolite.edges.shape[1] == 2
+        assert len(metabolite.edges_as_tuples) == metabolite.edges.shape[0]
+        assert metabolite.node_features.shape[0] == metabolite.Graph.number_of_nodes()
+        assert metabolite.bond_features.shape[0] == len(metabolite.edges_as_tuples)
+        assert metabolite.setup_features.shape[0] == 1
+        assert metabolite.setup_features_per_edge.shape[0] == len(
+            metabolite.edges_as_tuples
+        )
+
+
+def test_metabolite_index_and_fragmentation_trees(indexed_metabolites):
+    index = indexed_metabolites["index"]
+    metabolites = indexed_metabolites["metabolites"]
+    mismatches = indexed_metabolites["mismatches"]
+
+    assert len(mismatches) == 0
+    assert 0 < index.get_number_of_metabolites() <= SAMPLE_SIZE
+
+    for metabolite in metabolites:
+        assert metabolite.fragmentation_tree is not None
+        assert metabolite.subgraph_elem_comp.shape[0] == metabolite.edges.shape[0]
+        assert metabolite.subgraph_idx_left.shape == metabolite.subgraph_idx_right.shape
+
+
+def test_peak_matching_and_geometric_export(prepared_metabolites, indexed_metabolites):
+    _ = indexed_metabolites  # Ensure fragmentation trees are attached.
+    matched = 0
+
+    for entry in prepared_metabolites:
+        metabolite = entry["metabolite"]
+        peaks = entry["peaks"]
+        mz_list = peaks["mz"]
+        int_list = peaks["intensity"]
+
+        metabolite.match_fragments_to_peaks(mz_list, int_list)
+        geom = metabolite.as_geometric_data()
+
+        assert metabolite.match_stats["num_peaks"] == len(mz_list)
+        assert (
+            geom.compiled_probsALL.shape[0] == metabolite.edge_count_matrix.numel() + 2
+        )
+        assert (
+            geom.compiled_validation_maskALL.shape[0] == geom.compiled_probsALL.shape[0]
+        )
+        assert torch.isfinite(geom.compiled_probsSQRT).all()
+        matched += 1
+
+    assert matched == SAMPLE_SIZE
+
+
+def test_edge_count_cols_helper():
+    mode_map = {"[M-H]-": 0, "[M+H]+": 1}
+    mode_count = len(mode_map)
+
+    left_forward, left_backward = Metabolite._edge_count_cols(
+        mode_map, mode_count, "[M+H]+", "left"
+    )
+    right_forward, right_backward = Metabolite._edge_count_cols(
+        mode_map, mode_count, "[M+H]+", "right"
+    )
+
+    assert (left_forward, left_backward) == (1, 3)
+    assert (right_forward, right_backward) == (3, 1)
+
+
+def test_collision_energy_helpers():
+    assert NCE_to_eV(20.0, 250.0) == pytest.approx(10.0)
+    assert align_CE("35eV", 200.0) == pytest.approx(35.0)
+    assert align_CE("2keV", 200.0) == pytest.approx(2000.0)
+    assert align_CE(20.0, 250.0, instrument=nce_instruments[0]) == pytest.approx(10.0)
+    assert align_CE("15% (nominal)", 300.0) == pytest.approx(NCE_to_eV(15.0, 300.0))
+
+
+def test_edge_count_matrix_accumulates_repeated_matches():
+    class _StubFragment:
+        def __init__(self, edge, break_side):
+            self.edges = [edge]
+            self.break_sides = [break_side]
+
+        def num_of_edges(self):
+            return 1
+
+    class _StubTree:
+        def __init__(self, peak_matches):
+            self.peak_matches = peak_matches
+
+        def match_peak_list(self, mz_list, int_list, tolerance=None):
+            return self.peak_matches
+
+    mode_map = {"[M+H]+": 0}
+    edge = (0, 1)
+
+    peak_matches = {
+        100.0: {
+            "intensity": 10.0,
+            "relative_intensity": 10.0 / 15.0,
+            "fragments": [_StubFragment(edge=edge, break_side="left")],
+            "ion_modes": [("[M+H]+", 100.0)],
+        },
+        101.0: {
+            "intensity": 5.0,
+            "relative_intensity": 5.0 / 15.0,
+            "fragments": [_StubFragment(edge=edge, break_side="left")],
+            "ion_modes": [("[M+H]+", 101.0)],
+        },
+    }
+
+    metabolite = Metabolite("CC")
+    metabolite.create_molecular_structure_graph()
+    metabolite.compute_graph_attributes()
+    metabolite.fragmentation_tree = _StubTree(peak_matches)
+
+    metabolite.match_fragments_to_peaks(
+        mz_fragments=[100.0, 101.0],
+        int_list=[10.0, 5.0],
+        mode_map_override=mode_map,
+    )
+
+    forward_col, backward_col = Metabolite._edge_count_cols(
+        mode_map, len(mode_map), "[M+H]+", "left"
+    )
+    forward_idx = (
+        ((torch.tensor(edge) == metabolite.edges).sum(dim=1) == 2).nonzero().squeeze()
+    )
+    backward_idx = (
+        ((torch.tensor(edge[::-1]) == metabolite.edges).sum(dim=1) == 2)
+        .nonzero()
+        .squeeze()
+    )
+
+    assert metabolite.edge_count_matrix[
+        forward_idx, forward_col
+    ].item() == pytest.approx(15.0)
+    assert metabolite.edge_count_matrix[
+        backward_idx, backward_col
+    ].item() == pytest.approx(15.0)

From 40bbb4a625a4ed480bb668e8967a87983299ede0 Mon Sep 17 00:00:00 2001
From: Philipp Benner
Date: Sun, 19 Apr 2026 09:23:11 +0200
Subject: [PATCH 10/15] Add Ruff CI and pre-commit to fiora

---
 .github/workflows/actions.yml | 66 ++++
 .pre-commit-config.yaml | 7 +
 fiora/GNN/AtomFeatureEncoder.py | 106 ++---
 fiora/GNN/BondFeatureEncoder.py | 63 +--
 fiora/GNN/CovariateFeatureEncoder.py | 57 +--
 fiora/GNN/Datasets.py | 54 +--
 fiora/GNN/EdgePropertyPredictor.py | 20 +-
 fiora/GNN/FeatureEmbedding.py | 33 +-
 fiora/GNN/FioraModel.py | 202 +++++-----
 fiora/GNN/GNN.py | 73 ++--
 fiora/GNN/GraphPropertyPredictor.py | 14 +-
 fiora/GNN/Losses.py | 32 +-
 fiora/GNN/Trainer.py | 79 ++--
 fiora/GNN/fabric_training.py | 198 +++++-----
 fiora/IO/LibraryLoader.py | 4 +-
 fiora/IO/cfmReader.py | 54 +--
 fiora/IO/fraggraphReader.py | 20 +-
 fiora/IO/mgfReader.py | 42 +-
 fiora/IO/mgfWriter.py | 20 +-
 fiora/IO/molReader.py | 3 +-
 fiora/IO/mspReader.py | 72 ++--
 fiora/IO/mspWriter.py | 34 +-
 fiora/IO/mspredReader.py | 39 +-
 fiora/IO/mspredWriter.py | 95 ++---
 fiora/MOL/FragmentationTree.py | 63 ++-
 fiora/MOL/Metabolite.py | 258 ++++++-------
 fiora/MOL/MetaboliteDatasetStatistics.py | 48 +--
 fiora/MOL/MetaboliteIndex.py | 27 +-
 fiora/MOL/collision_energy.py | 46 +--
 fiora/MOL/constants.py | 76 ++--
 fiora/MOL/mol_graph.py | 25 +-
 fiora/MS/SimulationFramework.py | 197 +++++-----
 fiora/MS/ms_utility.py | 66 ++--
 fiora/MS/spectral_scores.py | 43 +--
 fiora/cli/eval.py | 232 +++++------
 fiora/cli/model_info.py | 90 ++---
 fiora/cli/predict.py | 278 ++++++-------
 fiora/cli/train.py | 428 ++++++++++-----------
 fiora/visualization/define_colors.py | 190 ++++-----
 fiora/visualization/inspect_mgf_file.py | 37 +-
 fiora/visualization/plot_spectrum.py | 35 +-
 fiora/visualization/spectrum_visualizer.py | 69 ++--
 pyproject.toml | 88 ++++-
 tests/test_fiora_eval.py | 32 +-
 tests/test_fiora_model_info.py | 32 +-
 tests/test_fiora_predict.py | 70 ++--
 tests/test_mol_core.py | 115 +++---
 tests/test_trainer_history.py | 14 +-
 48 files changed, 2058 insertions(+), 1888 deletions(-)
 create mode 100644 .github/workflows/actions.yml
 create mode 100644 .pre-commit-config.yaml

diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml
new file mode 100644
index 0000000..0718046
--- /dev/null
+++ b/.github/workflows/actions.yml
@@ -0,0 +1,66 @@
+name: Installing and Testing
+on: [push]
+
+permissions:
+  contents: write
+  checks: write
+  pull-requests: write
+
+env:
+  UV_SYSTEM_PYTHON: true
+
+jobs:
+  install-and-test:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+      - name: Install uv
+        run: |
+          curl -LsSf https://astral.sh/uv/install.sh | sh
-LsSf https://astral.sh/uv/install.sh | sh + - name: Install dependencies + run: | + pip install --upgrade pip + uv pip install -e '.[dev]' + # - name: Test with pytest + # run: | + # python -m pytest -sv tests + build-and-install: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + - name: Build the package + run: | + uv pip install build + python -m build --sdist + - name: Install the package + run: | + uv pip install dist/*.tar.gz + ruff-linting: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v3 + with: + version: "0.15.11" + args: "check" + src: "." + ruff-formatting: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v3 + with: + version: "0.15.11" + args: "format --check --verbose" + src: "." diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..6e71d57 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,7 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.11 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/fiora/GNN/AtomFeatureEncoder.py b/fiora/GNN/AtomFeatureEncoder.py index 77daf10..8cc271e 100644 --- a/fiora/GNN/AtomFeatureEncoder.py +++ b/fiora/GNN/AtomFeatureEncoder.py @@ -1,25 +1,27 @@ +from typing import Literal + import torch from rdkit import Chem -from typing import Literal + from fiora.MOL.constants import ORDERED_ELEMENT_LIST class AtomFeatureEncoder: - def __init__(self, feature_list=["symbol", "num_hydrogen", "ring_type"]): + def __init__(self, feature_list=['symbol', 'num_hydrogen', 'ring_type']): self.encoding_dim = 0 self.sets = { - "symbol": ORDERED_ELEMENT_LIST, # OTHERS: Au, Se, Si #standard list {"B", "Br", "C", "Ca", "Cl", "F", "H", "I", "N", "Na", "O", "P", "S"}, - "num_hydrogen": [0, 1, 2, 3], # OTHERS: 5, 6, 7, 8}, - "ring_type": ["no-ring", "small-ring", "5-cycle", "6-cycle", "large-ring"], - "hybridization": ["SP", "SP2", "SP3", "SP3D2"], - "valence_electrons": [1, 2, 3, 4, 5, 6, 7, 8], - "oxidation_number": [1, 2, 3, 4, 5, 6, 7, 8, 9], + 'symbol': ORDERED_ELEMENT_LIST, # OTHERS: Au, Se, Si #standard list {"B", "Br", "C", "Ca", "Cl", "F", "H", "I", "N", "Na", "O", "P", "S"}, + 'num_hydrogen': [0, 1, 2, 3], # OTHERS: 5, 6, 7, 8}, + 'ring_type': ['no-ring', 'small-ring', '5-cycle', '6-cycle', 'large-ring'], + 'hybridization': ['SP', 'SP2', 'SP3', 'SP3D2'], + 'valence_electrons': [1, 2, 3, 4, 5, 6, 7, 8], + 'oxidation_number': [1, 2, 3, 4, 5, 6, 7, 8, 9], } self.feature_list = feature_list self.reduced_features = [ - "symbol", - "num_hydrogen", - "hybridization", + 'symbol', + 'num_hydrogen', + 'hybridization', ] # Reduced features may have additional values that will be encoded with another bit (representing OTHERS) self.one_hot_mapper = {} @@ -41,62 +43,62 @@ def __init__(self, feature_list=["symbol", "num_hydrogen", "ring_type"]): num_variables += 1 self.feature_numbers[feature] = num_variables - def encode(self, G, encoder_type: Literal["one_hot", "number"]): + def encode(self, G, encoder_type: Literal['one_hot', 'number']): - if encoder_type == "one_hot": + if encoder_type == 'one_hot': feature_matrix = torch.zeros( G.number_of_nodes(), self.encoding_dim, dtype=torch.float32 ) for i in range(G.number_of_nodes()): - atom = G.nodes()[i]["atom"] + atom = G.nodes()[i]['atom'] - if 
"symbol" in self.feature_list: - if atom.GetSymbol() not in self.sets["symbol"]: + if 'symbol' in self.feature_list: + if atom.GetSymbol() not in self.sets['symbol']: feature_matrix[i][ - self.one_hot_mapper["symbol"][list(self.sets["symbol"])[-1]] + self.one_hot_mapper['symbol'][list(self.sets['symbol'])[-1]] + 1 ] = 1.0 else: feature_matrix[i][ - self.one_hot_mapper["symbol"][atom.GetSymbol()] + self.one_hot_mapper['symbol'][atom.GetSymbol()] ] = 1.0 - if "num_hydrogen" in self.feature_list: + if 'num_hydrogen' in self.feature_list: value = atom.GetTotalNumHs() - if value in self.sets["num_hydrogen"]: + if value in self.sets['num_hydrogen']: feature_matrix[i][ - self.one_hot_mapper["num_hydrogen"][atom.GetTotalNumHs()] + self.one_hot_mapper['num_hydrogen'][atom.GetTotalNumHs()] ] = 1.0 else: feature_matrix[i][ - self.one_hot_mapper["num_hydrogen"][ - list(self.sets["num_hydrogen"])[-1] + self.one_hot_mapper['num_hydrogen'][ + list(self.sets['num_hydrogen'])[-1] ] + 1 ] = 1.0 - if "ring_type" in self.feature_list: + if 'ring_type' in self.feature_list: if not atom.IsInRing(): - ring_type = "no-ring" + ring_type = 'no-ring' elif atom.IsInRingSize(7): - ring_type = "large-ring" + ring_type = 'large-ring' elif atom.IsInRingSize(6): - ring_type = "6-cycle" + ring_type = '6-cycle' elif atom.IsInRingSize(5): - ring_type = "5-cycle" + ring_type = '5-cycle' else: - ring_type = "small-ring" - feature_matrix[i][self.one_hot_mapper["ring_type"][ring_type]] = 1.0 - if "hybridization" in self.feature_list: + ring_type = 'small-ring' + feature_matrix[i][self.one_hot_mapper['ring_type'][ring_type]] = 1.0 + if 'hybridization' in self.feature_list: orbi = atom.GetHybridization().name - if orbi in self.sets["hybridization"]: + if orbi in self.sets['hybridization']: feature_matrix[i][ - self.one_hot_mapper["hybridization"][orbi] + self.one_hot_mapper['hybridization'][orbi] ] = 1.0 else: feature_matrix[i][ - self.one_hot_mapper["hybridization"][ - list(self.sets["hybridization"])[-1] + self.one_hot_mapper['hybridization'][ + list(self.sets['hybridization'])[-1] ] + 1 ] = 1.0 @@ -106,51 +108,51 @@ def encode(self, G, encoder_type: Literal["one_hot", "number"]): G.number_of_nodes(), len(self.feature_list), dtype=torch.int ) for i in range(G.number_of_nodes()): - atom = G.nodes()[i]["atom"] + atom = G.nodes()[i]['atom'] for j, feature in enumerate(self.feature_list): - if feature == "symbol": - if atom.GetSymbol() in self.sets["symbol"]: + if feature == 'symbol': + if atom.GetSymbol() in self.sets['symbol']: feature_matrix[i][j] = self.number_mapper[feature][ atom.GetSymbol() ] else: feature_matrix[i][j] = self.feature_numbers[feature] - 1 - elif feature == "num_hydrogen": + elif feature == 'num_hydrogen': value = atom.GetTotalNumHs() - if value in self.sets["num_hydrogen"]: + if value in self.sets['num_hydrogen']: feature_matrix[i][j] = self.number_mapper[feature][value] else: feature_matrix[i][j] = self.feature_numbers[feature] - 1 - elif feature == "valence_electrons": + elif feature == 'valence_electrons': value = atom.GetExplicitValence() - if value in self.sets["valence_electrons"]: + if value in self.sets['valence_electrons']: feature_matrix[i][j] = self.number_mapper[feature][value] else: feature_matrix[i][j] = self.feature_numbers[feature] - 1 - elif feature == "oxidation_number": + elif feature == 'oxidation_number': raise NotImplementedError() value = Chem.rdMolDescriptors.CalcOxidationNumbers(atom) - if value in self.sets["oxidation_number"]: + if value in self.sets['oxidation_number']: 
feature_matrix[i][j] = self.number_mapper[feature][value] else: feature_matrix[i][j] = self.feature_numbers[feature] - 1 - elif feature == "ring_type": + elif feature == 'ring_type': if not atom.IsInRing(): - ring_type = "no-ring" + ring_type = 'no-ring' elif atom.IsInRingSize(7): - ring_type = "large-ring" + ring_type = 'large-ring' elif atom.IsInRingSize(6): - ring_type = "6-cycle" + ring_type = '6-cycle' elif atom.IsInRingSize(5): - ring_type = "5-cycle" + ring_type = '5-cycle' else: - ring_type = "small-ring" + ring_type = 'small-ring' feature_matrix[i][j] = self.number_mapper[feature][ring_type] - if feature == "hybridization": + if feature == 'hybridization': orbi = atom.GetHybridization().name - if orbi in self.sets["hybridization"]: + if orbi in self.sets['hybridization']: feature_matrix[i][j] = self.number_mapper[feature][orbi] else: feature_matrix[i][j] = self.feature_numbers[feature] - 1 diff --git a/fiora/GNN/BondFeatureEncoder.py b/fiora/GNN/BondFeatureEncoder.py index 05aaf60..ed7a3b5 100644 --- a/fiora/GNN/BondFeatureEncoder.py +++ b/fiora/GNN/BondFeatureEncoder.py @@ -1,15 +1,16 @@ -import torch from typing import Literal +import torch + class BondFeatureEncoder: - def __init__(self, feature_list=["bond_type", "ring_type"]): + def __init__(self, feature_list=['bond_type', 'ring_type']): self.encoding_dim = 0 self.feature_list = feature_list self.sets = { - "bond_type": ["AROMATIC", "SINGLE", "DOUBLE", "TRIPLE"], - "ring_type": ["no-ring", "small-ring", "5-cycle", "6-cycle", "large-ring"], - "ring_type_binary": ["is_in_ring"], + 'bond_type': ['AROMATIC', 'SINGLE', 'DOUBLE', 'TRIPLE'], + 'ring_type': ['no-ring', 'small-ring', '5-cycle', '6-cycle', 'large-ring'], + 'ring_type_binary': ['is_in_ring'], } self.reduced_features = [] # Reduced features may have additional values that will be encoded with another bit (representing OTHERS) self.one_hot_mapper = {} @@ -32,67 +33,67 @@ def __init__(self, feature_list=["bond_type", "ring_type"]): num_variables += 1 self.feature_numbers[feature] = num_variables - def encode(self, G, edges, encoder_type: Literal["one_hot", "number"]): + def encode(self, G, edges, encoder_type: Literal['one_hot', 'number']): - if encoder_type == "one_hot": + if encoder_type == 'one_hot': feature_matrix = torch.zeros( len(edges), self.encoding_dim, dtype=torch.float32 ) for i, (u, v) in enumerate(edges): - bond = G[u][v]["bond"] - if "bond_type" in self.feature_list: + bond = G[u][v]['bond'] + if 'bond_type' in self.feature_list: feature_matrix[i][ - self.one_hot_mapper["bond_type"][bond.GetBondType().name] + self.one_hot_mapper['bond_type'][bond.GetBondType().name] ] = 1.0 - if "ring_type" in self.feature_list: + if 'ring_type' in self.feature_list: if not bond.IsInRing(): - ring_type = "no-ring" + ring_type = 'no-ring' elif bond.IsInRingSize(7): - ring_type = "large-ring" + ring_type = 'large-ring' elif bond.IsInRingSize(6): - ring_type = "6-cycle" + ring_type = '6-cycle' elif bond.IsInRingSize(5): - ring_type = "5-cycle" + ring_type = '5-cycle' else: - ring_type = "small-ring" - feature_matrix[i][self.one_hot_mapper["ring_type"][ring_type]] = 1.0 - if "ring_type_binary" in self.feature_list: + ring_type = 'small-ring' + feature_matrix[i][self.one_hot_mapper['ring_type'][ring_type]] = 1.0 + if 'ring_type_binary' in self.feature_list: if bond.IsInRing(): feature_matrix[i][ - self.one_hot_mapper["ring_type_binary"]["is_in_ring"] + self.one_hot_mapper['ring_type_binary']['is_in_ring'] ] = 1.0 # else case implicit = 0 - elif encoder_type == "number": # Case: 
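# Illustrative sketch (not from the patch): atom and bond encoders share the
# same IsInRingSize cascade. A standalone RDKit demo; note that rings larger
# than seven members fall through to 'small-ring' in this ordering.
from rdkit import Chem

def ring_type(ring_object):  # works for RDKit atoms and bonds alike
    if not ring_object.IsInRing():
        return 'no-ring'
    if ring_object.IsInRingSize(7):
        return 'large-ring'
    if ring_object.IsInRingSize(6):
        return '6-cycle'
    if ring_object.IsInRingSize(5):
        return '5-cycle'
    return 'small-ring'

mol = Chem.MolFromSmiles('C1CC1c1ccccc1')  # cyclopropane attached to benzene
print({atom.GetIdx(): ring_type(atom) for atom in mol.GetAtoms()})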
Number mapping + elif encoder_type == 'number': # Case: Number mapping feature_matrix = torch.zeros( len(edges), len(self.feature_list), dtype=torch.int ) for i, (u, v) in enumerate(edges): - bond = G[u][v]["bond"] + bond = G[u][v]['bond'] for j, feature in enumerate(self.feature_list): - if feature == "bond_type": + if feature == 'bond_type': value = bond.GetBondType().name - if value in self.sets["bond_type"]: + if value in self.sets['bond_type']: feature_matrix[i][j] = self.number_mapper[feature][value] else: raise NotImplementedError( - "Unknown bond type is not accounted for." + 'Unknown bond type is not accounted for.' ) - elif feature == "ring_type": + elif feature == 'ring_type': if not bond.IsInRing(): - ring_type = "no-ring" + ring_type = 'no-ring' elif bond.IsInRingSize(7): - ring_type = "large-ring" + ring_type = 'large-ring' elif bond.IsInRingSize(6): - ring_type = "6-cycle" + ring_type = '6-cycle' elif bond.IsInRingSize(5): - ring_type = "5-cycle" + ring_type = '5-cycle' else: - ring_type = "small-ring" + ring_type = 'small-ring' feature_matrix[i][j] = self.number_mapper[feature][ring_type] - if feature == "ring_type_binary": + if feature == 'ring_type_binary': raise NotImplementedError( "Binary feature not implemented with number embedding. Use default 'ring_type' instead." ) diff --git a/fiora/GNN/CovariateFeatureEncoder.py b/fiora/GNN/CovariateFeatureEncoder.py index b2890b7..cc92f1f 100644 --- a/fiora/GNN/CovariateFeatureEncoder.py +++ b/fiora/GNN/CovariateFeatureEncoder.py @@ -1,4 +1,5 @@ import torch + from fiora.MOL.constants import ORDERED_ELEMENT_LIST_WITH_HYDROGEN @@ -6,41 +7,41 @@ class CovariateFeatureEncoder: def __init__( self, feature_list=[ - "collision_energy", - "molecular_weight", - "precursor_mode", - "instrument", - "element_composition", + 'collision_energy', + 'molecular_weight', + 'precursor_mode', + 'instrument', + 'element_composition', ], sets_overwrite: dict | None = None, ): - if "ce_steps" in feature_list: + if 'ce_steps' in feature_list: raise ValueError( "'ce_steps' is not meant as a setup feature. 
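# Illustrative sketch (not from the patch): continuous covariates above are
# min-max scaled with the bounds from normalize_features and clamped to [0, 1].
def scale_01(value, lo, hi):
    return min(max((value - lo) / (hi - lo), 0.0), 1.0)

print(scale_01(35.0, 0, 100))     # 0.35 -> collision_energy of 35
print(scale_01(1200.0, 0, 1000))  # 1.0  -> molecular_weight above range, clamped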
Remove from feature_list" ) self.encoding_dim = 0 self.feature_list = feature_list self.categorical_sets = { - "instrument": [ - "HCD", - "Q-TOF", - "IT-FT/ion trap with FTMS", - "IT/ion trap", + 'instrument': [ + 'HCD', + 'Q-TOF', + 'IT-FT/ion trap with FTMS', + 'IT/ion trap', ], # "IT-FT/ion trap with FTMS", "IT/ion trap", "QqQ", "QqQ/triple quadrupole" - "precursor_mode": ["[M+H]+", "[M-H]-"], + 'precursor_mode': ['[M+H]+', '[M-H]-'], } if sets_overwrite: for new_set, new_categories in sets_overwrite.items(): self.categorical_sets[new_set] = new_categories - self.continuous_set = {"collision_energy", "molecular_weight"} + self.continuous_set = {'collision_energy', 'molecular_weight'} self.normalize_features = { - "collision_energy": {"min": 0, "max": 100, "transform": "linear"}, - "molecular_weight": {"min": 0, "max": 1000, "transform": "linear"}, + 'collision_energy': {'min': 0, 'max': 100, 'transform': 'linear'}, + 'molecular_weight': {'min': 0, 'max': 1000, 'transform': 'linear'}, } self.reduced_categorical_features = [ - "instrument" + 'instrument' ] # Reduced features may have additional values that will be encoded with another bit (representing OTHERS) self.one_hot_mapper = {} for feature in self.feature_list: @@ -59,8 +60,8 @@ def __init__( self.one_hot_mapper[feature] = self.encoding_dim self.encoding_dim += 1 - if "element_composition" in self.feature_list: - self.one_hot_mapper["element_composition"] = { + if 'element_composition' in self.feature_list: + self.one_hot_mapper['element_composition'] = { element: idx for idx, element in enumerate( ORDERED_ELEMENT_LIST_WITH_HYDROGEN, start=self.encoding_dim @@ -87,14 +88,14 @@ def encode(self, dim0, metadata, G=None): elif feature in self.continuous_set: value = metadata[feature] if feature in self.normalize_features.keys(): - value = (value - self.normalize_features[feature]["min"]) / ( - self.normalize_features[feature]["max"] - - self.normalize_features[feature]["min"] + value = (value - self.normalize_features[feature]['min']) / ( + self.normalize_features[feature]['max'] + - self.normalize_features[feature]['min'] ) feature_matrix[:, self.one_hot_mapper[feature]] = value feature_matrix = torch.clamp(feature_matrix, 0.0, 1.0) - elif feature == "element_composition": + elif feature == 'element_composition': if G is None: raise ValueError( "Graph G must be provided to encode 'element_composition'" @@ -102,16 +103,16 @@ def encode(self, dim0, metadata, G=None): element_composition = self.get_element_composition(G) for idx, element in enumerate(ORDERED_ELEMENT_LIST_WITH_HYDROGEN): feature_matrix[ - :, self.one_hot_mapper["element_composition"][element] + :, self.one_hot_mapper['element_composition'][element] ] = element_composition[idx] return feature_matrix def normalize_collision_steps(self, ce_steps): norm_ce = lambda x: ( - (x - self.normalize_features["collision_energy"]["min"]) + (x - self.normalize_features['collision_energy']['min']) / ( - self.normalize_features["collision_energy"]["max"] - - self.normalize_features["collision_energy"]["min"] + self.normalize_features['collision_energy']['max'] + - self.normalize_features['collision_energy']['min'] ) ) ce_steps = [norm_ce(x) for x in ce_steps] @@ -126,7 +127,7 @@ def get_element_composition(G): # Iterate through nodes in the graph for node in G.nodes: - atom = G.nodes[node]["atom"] + atom = G.nodes[node]['atom'] symbol = atom.GetSymbol() # Get the atomic symbol if symbol in ORDERED_ELEMENT_LIST_WITH_HYDROGEN: index = ORDERED_ELEMENT_LIST_WITH_HYDROGEN.index( @@ -137,7 
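# Illustrative sketch (not from the patch): get_element_composition above
# counts heavy atoms per element and adds RDKit's implicit hydrogen counts.
# ELEMENTS below is a shortened stand-in for ORDERED_ELEMENT_LIST_WITH_HYDROGEN.
from rdkit import Chem

ELEMENTS = ['C', 'N', 'O', 'H']

def element_composition(smiles):
    counts = [0] * len(ELEMENTS)
    mol = Chem.MolFromSmiles(smiles)
    for atom in mol.GetAtoms():
        if atom.GetSymbol() in ELEMENTS:
            counts[ELEMENTS.index(atom.GetSymbol())] += 1
        counts[ELEMENTS.index('H')] += atom.GetTotalNumHs()
    return dict(zip(ELEMENTS, counts))

print(element_composition('CCO'))  # {'C': 2, 'N': 0, 'O': 1, 'H': 6} for ethanol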
+138,7 @@ def get_element_composition(G): # Add hydrogens explicitly hydrogens = atom.GetTotalNumHs() hydrogen_index = ORDERED_ELEMENT_LIST_WITH_HYDROGEN.index( - "H" + 'H' ) # Ensure 'H' is in ORDERED_ELEMENT_LIST element_composition[hydrogen_index] += hydrogens diff --git a/fiora/GNN/Datasets.py b/fiora/GNN/Datasets.py index 0fe23b3..0b1f72a 100644 --- a/fiora/GNN/Datasets.py +++ b/fiora/GNN/Datasets.py @@ -1,6 +1,6 @@ -import torch -import pandas as pd import numpy as np +import pandas as pd +import torch from torch.utils.data import Dataset """ @@ -24,13 +24,13 @@ class AtomAromaticityData(Dataset): def __init__(self, df) -> None: self.X = ( - df["features"].apply(lambda x: torch.tensor(x, dtype=torch.float32)).values + df['features'].apply(lambda x: torch.tensor(x, dtype=torch.float32)).values ) self.A = ( - df["Atilde"].apply(lambda x: torch.tensor(x, dtype=torch.float32)).values + df['Atilde'].apply(lambda x: torch.tensor(x, dtype=torch.float32)).values ) self.y = ( - df["is_aromatic"] + df['is_aromatic'] .apply(lambda x: torch.tensor(x, dtype=torch.float32)) .values * 1 @@ -48,7 +48,7 @@ def num_features(self): class SimpleNodeData(Dataset): def __init__( - self, data: pd.Series, feature_tag: str, label: str, device="cpu" + self, data: pd.Series, feature_tag: str, label: str, device='cpu' ) -> None: self.X = torch.cat( data.apply(lambda x: getattr(x, feature_tag).to(device)).to_list() @@ -67,7 +67,7 @@ def num_features(self): class NodeSingleLabelData(Dataset): def __init__( - self, data: pd.Series, feature_tag: str, adj_tag: str, label: str, device="cpu" + self, data: pd.Series, feature_tag: str, adj_tag: str, label: str, device='cpu' ) -> None: self.X = data.apply(lambda x: getattr(x, feature_tag).to(device)).values @@ -98,7 +98,7 @@ def __init__( label: str, validation_mask_tag: str, group_id: str, - device="cpu", + device='cpu', ) -> None: self.X = data.apply(lambda x: getattr(x, feature_tag).to(device)).values @@ -189,12 +189,12 @@ def collate_graph_batch(batch): adj_mask[i, : num_nodes[i], : num_nodes[i]] = 1 batch_record = { - "X": X, - "A": A, - "y": y, - "node_mask": node_mask, - "adj_mask": adj_mask, - "num_of_nodes": torch.tensor(list(num_nodes)).unsqueeze(dim=1), + 'X': X, + 'A': A, + 'y': y, + 'node_mask': node_mask, + 'adj_mask': adj_mask, + 'num_of_nodes': torch.tensor(list(num_nodes)).unsqueeze(dim=1), } return batch_record @@ -254,18 +254,18 @@ def pad_matrix(x, y, z): validation_mask[i, : num_edges[i]] = validation_bits[i].flatten() batch_record = { - "X": X, - "A": A, - "y": y, - "AL": AL, - "AR": AR, - "node_mask": node_mask, - "adj_mask": adj_mask, - "y_mask": y_mask, - "edge_features": edge_features, - "static_features": static_features, - "validation_mask": validation_mask, - "num_of_nodes": torch.tensor(list(num_nodes)).unsqueeze(dim=1), - "num_of_edges": torch.tensor(list(num_edges)).unsqueeze(dim=1), + 'X': X, + 'A': A, + 'y': y, + 'AL': AL, + 'AR': AR, + 'node_mask': node_mask, + 'adj_mask': adj_mask, + 'y_mask': y_mask, + 'edge_features': edge_features, + 'static_features': static_features, + 'validation_mask': validation_mask, + 'num_of_nodes': torch.tensor(list(num_nodes)).unsqueeze(dim=1), + 'num_of_edges': torch.tensor(list(num_edges)).unsqueeze(dim=1), } return batch_record diff --git a/fiora/GNN/EdgePropertyPredictor.py b/fiora/GNN/EdgePropertyPredictor.py index d2dc846..624f3fc 100644 --- a/fiora/GNN/EdgePropertyPredictor.py +++ b/fiora/GNN/EdgePropertyPredictor.py @@ -1,7 +1,7 @@ +from typing import Dict, Literal + import torch -from typing 
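# Illustrative sketch (not from the patch): collate_graph_batch above pads
# graphs of different sizes into dense tensors and records masks so padded
# entries can be ignored later. The node-feature part of that idea:
import torch

def pad_node_features(features):
    n_max = max(x.shape[0] for x in features)
    X = torch.zeros(len(features), n_max, features[0].shape[1])
    node_mask = torch.zeros(len(features), n_max, dtype=torch.bool)
    for i, x in enumerate(features):
        X[i, : x.shape[0]] = x
        node_mask[i, : x.shape[0]] = True
    return X, node_mask

X, node_mask = pad_node_features([torch.randn(3, 4), torch.randn(5, 4)])
print(X.shape, node_mask.sum(dim=1))  # torch.Size([2, 5, 4]) tensor([3, 5])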
import Dict import torch_geometric.nn as geom_nn -from typing import Literal from fiora.MOL.constants import ORDERED_ELEMENT_LIST_WITH_HYDROGEN @@ -16,10 +16,10 @@ def __init__( dense_depth: int = 0, dense_dim: int = None, embedding_dim: int = 200, - embedding_aggregation_type: str = "concat", + embedding_aggregation_type: str = 'concat', residual_connections: bool = False, subgraph_features: bool = False, - pooling_func: Literal["avg", "max"] = "avg", + pooling_func: Literal['avg', 'max'] = 'avg', input_dropout: float = 0, latent_dropout: float = 0, ) -> None: @@ -47,7 +47,7 @@ def __init__( self.subgraph_features = subgraph_features self.pooling_func = ( geom_nn.global_mean_pool - if pooling_func == "avg" + if pooling_func == 'avg' else geom_nn.global_max_pool ) num_subgraph_features = ( @@ -66,7 +66,7 @@ def __init__( hidden_dimension = dense_dim if dense_dim is not None else num_features if hidden_dimension != num_features and residual_connections: raise NotImplementedError( - "Residual connections require the hidden dimension to match the input dimension." + 'Residual connections require the hidden dimension to match the input dimension.' ) for _ in range(dense_depth): dense_layers += [torch.nn.Linear(num_features, hidden_dimension)] @@ -81,7 +81,7 @@ def concat_node_pairs(self, X, batch): This version includes debug messages to trace tensor shapes. """ - src, dst = batch["edge_index"] + src, dst = batch['edge_index'] # 1. Node Pair Concatenation X_src = X[src] @@ -129,7 +129,7 @@ def concat_node_pairs(self, X, batch): subgraph_features = torch.cat([pooled_left, pooled_right], dim=1) # 3. Final Concatenation with Subgraph Features - edge_elem_comp = batch["edge_elem_comp"] + edge_elem_comp = batch['edge_elem_comp'] # This is the likely point of failure node_pairs = torch.cat( @@ -144,11 +144,11 @@ def forward(self, X, batch): # Add edge features and static features edge_features = batch[ - "edge_embedding" + 'edge_embedding' ] # self.edge_embedding(batch["edge_attr"]) edge_features = self.input_dropout(edge_features) X = torch.cat( - [X, edge_features, batch["static_edge_features"]], axis=-1 + [X, edge_features, batch['static_edge_features']], axis=-1 ) # self.input_dropout(batch["static_edge_features"]) # Apply fully connected layers diff --git a/fiora/GNN/FeatureEmbedding.py b/fiora/GNN/FeatureEmbedding.py index 7d804ae..0c4338a 100644 --- a/fiora/GNN/FeatureEmbedding.py +++ b/fiora/GNN/FeatureEmbedding.py @@ -1,6 +1,7 @@ -import torch -from typing import Dict, Literal import warnings +from typing import Dict, Literal + +import torch class FeatureEmbedding(torch.nn.Module): @@ -8,26 +9,26 @@ def __init__( self, feature_dict: Dict[str, int], dim=200, - aggregation_type=Literal["concat", "sum"], + aggregation_type=Literal['concat', 'sum'], ) -> None: super().__init__() self.aggregation_type = aggregation_type self.feature_dim = dim - if aggregation_type == "concat": + if aggregation_type == 'concat': num_features = len(feature_dict.keys()) self.feature_dim = int(dim / num_features) self.dim = self.feature_dim * num_features if self.dim != dim: warnings.warn( - f"Desired embedding dimension not cleanly dividable by the number of features. Reducing dimension from {dim} to {self.dim}." + f'Desired embedding dimension not cleanly dividable by the number of features. Reducing dimension from {dim} to {self.dim}.' ) - elif aggregation_type == "sum": + elif aggregation_type == 'sum': self.dim = dim self.feature_dim = dim else: raise NameError( - f"Unknown aggregation type selected. 
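# Illustrative sketch (not from the patch): with 'concat' aggregation above,
# the requested embedding dimension is split evenly across features, so it is
# rounded down to a multiple of the feature count (hence the warning):
dim, num_features = 200, 3
feature_dim = dim // num_features            # 66 dims per feature
effective_dim = feature_dim * num_features   # 198, not the requested 200
print(feature_dim, effective_dim)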
Valid types are {aggregation_type}." + f'Unknown aggregation type selected. Valid types are {aggregation_type}.' ) self.embeddings = torch.nn.ModuleList( [ @@ -46,9 +47,9 @@ def forward(self, features, feature_mask=None): values = features[:, i] node_embeddings.append(embedding(values)) - if self.aggregation_type == "sum": + if self.aggregation_type == 'sum': embedded_features = torch.sum(torch.stack(node_embeddings, dim=-1), dim=-1) - elif self.aggregation_type == "concat": + elif self.aggregation_type == 'concat': embedded_features = torch.cat(node_embeddings, dim=-1) if feature_mask is not None: @@ -62,26 +63,26 @@ def __init__( self, feature_dict: Dict[str, int], dim=200, - aggregation_type=Literal["concat", "sum"], + aggregation_type=Literal['concat', 'sum'], ) -> None: super().__init__() self.aggregation_type = aggregation_type self.feature_dim = dim - if aggregation_type == "concat": + if aggregation_type == 'concat': num_features = len(feature_dict.keys()) self.feature_dim = int(dim / num_features) self.dim = self.feature_dim * num_features if self.dim != dim: warnings.warn( - f"Desired embedding dimension not cleanly dividable by the number of features. Reducing dimension from {dim} to {self.dim}." + f'Desired embedding dimension not cleanly dividable by the number of features. Reducing dimension from {dim} to {self.dim}.' ) - elif aggregation_type == "sum": + elif aggregation_type == 'sum': self.dim = dim self.feature_dim = dim else: raise NameError( - f"Unknown aggregation type selected. Valid types are {aggregation_type}." + f'Unknown aggregation type selected. Valid types are {aggregation_type}.' ) self.embeddings = torch.nn.ModuleList( [ @@ -100,9 +101,9 @@ def forward(self, features, feature_mask=None): values = features[:, :, i] node_embeddings.append(embedding(values)) - if self.aggregation_type == "sum": + if self.aggregation_type == 'sum': embedded_features = torch.sum(torch.stack(node_embeddings, dim=-1), dim=-1) - elif self.aggregation_type == "concat": + elif self.aggregation_type == 'concat': embedded_features = torch.cat(node_embeddings, dim=-1) if feature_mask is not None: diff --git a/fiora/GNN/FioraModel.py b/fiora/GNN/FioraModel.py index 55e85cd..5ce2c33 100644 --- a/fiora/GNN/FioraModel.py +++ b/fiora/GNN/FioraModel.py @@ -1,15 +1,17 @@ +import json + +# Misc +from typing import Dict, Literal + +import dill import torch +from fiora.GNN.EdgePropertyPredictor import EdgePropertyPredictor + # Fiora GNN Modules from fiora.GNN.FeatureEmbedding import FeatureEmbedding from fiora.GNN.GNN import GNN from fiora.GNN.GraphPropertyPredictor import GraphPropertyPredictor -from fiora.GNN.EdgePropertyPredictor import EdgePropertyPredictor - -# Misc -from typing import Literal, Dict -import dill -import json class FioraModel(torch.nn.Module): @@ -22,81 +24,81 @@ def __init__(self, model_params: Dict) -> None: self._version_control_model_params(model_params) - self.edge_dim = model_params["output_dimension"] + self.edge_dim = model_params['output_dimension'] self.node_embedding = FeatureEmbedding( - feature_dict=model_params["node_feature_layout"], - dim=model_params["embedding_dimension"], - aggregation_type=model_params["embedding_aggregation"], + feature_dict=model_params['node_feature_layout'], + dim=model_params['embedding_dimension'], + aggregation_type=model_params['embedding_aggregation'], ) self.edge_embedding = FeatureEmbedding( - feature_dict=model_params["edge_feature_layout"], - dim=model_params["embedding_dimension"], - 
aggregation_type=model_params["embedding_aggregation"], + feature_dict=model_params['edge_feature_layout'], + dim=model_params['embedding_dimension'], + aggregation_type=model_params['embedding_aggregation'], ) self.GNN_module = GNN( - hidden_features=model_params["hidden_dimension"], - depth=model_params["depth"], + hidden_features=model_params['hidden_dimension'], + depth=model_params['depth'], embedding_dim=self.node_embedding.get_embedding_dimension(), - embedding_aggregation_type=model_params["embedding_aggregation"], - gnn_type=model_params["gnn_type"], - layer_norm=model_params["layer_norm"], - residual_connections=model_params["residual_connections"], - layer_stacking=model_params["layer_stacking"], - input_dropout=model_params["input_dropout"], - latent_dropout=model_params["latent_dropout"], + embedding_aggregation_type=model_params['embedding_aggregation'], + gnn_type=model_params['gnn_type'], + layer_norm=model_params['layer_norm'], + residual_connections=model_params['residual_connections'], + layer_stacking=model_params['layer_stacking'], + input_dropout=model_params['input_dropout'], + latent_dropout=model_params['latent_dropout'], ) self.edge_module = EdgePropertyPredictor( - edge_feature_dict=model_params["edge_feature_layout"], + edge_feature_dict=model_params['edge_feature_layout'], hidden_features=self.GNN_module.get_embedding_dimension(), - static_features=model_params["static_feature_dimension"], - out_dimension=model_params["output_dimension"], - dense_depth=model_params["dense_layers"], - dense_dim=model_params["dense_dim"], + static_features=model_params['static_feature_dimension'], + out_dimension=model_params['output_dimension'], + dense_depth=model_params['dense_layers'], + dense_dim=model_params['dense_dim'], embedding_dim=self.edge_embedding.get_embedding_dimension(), - embedding_aggregation_type=model_params["embedding_aggregation"], - residual_connections=model_params["residual_connections"], - subgraph_features=model_params["subgraph_features"], - pooling_func=model_params["pooling_func"], - input_dropout=model_params["input_dropout"], - latent_dropout=model_params["latent_dropout"], + embedding_aggregation_type=model_params['embedding_aggregation'], + residual_connections=model_params['residual_connections'], + subgraph_features=model_params['subgraph_features'], + pooling_func=model_params['pooling_func'], + input_dropout=model_params['input_dropout'], + latent_dropout=model_params['latent_dropout'], ) self.precursor_module = GraphPropertyPredictor( hidden_features=self.GNN_module.get_embedding_dimension(), - static_features=model_params["static_feature_dimension"], + static_features=model_params['static_feature_dimension'], out_dimension=1, - dense_depth=model_params["dense_layers"], - dense_dim=model_params["dense_dim"], - residual_connections=model_params["residual_connections"], - pooling_func=model_params["pooling_func"], - input_dropout=model_params["input_dropout"], - latent_dropout=model_params["latent_dropout"], + dense_depth=model_params['dense_layers'], + dense_dim=model_params['dense_dim'], + residual_connections=model_params['residual_connections'], + pooling_func=model_params['pooling_func'], + input_dropout=model_params['input_dropout'], + latent_dropout=model_params['latent_dropout'], ) - if model_params["prepare_additional_layers"]: + if model_params['prepare_additional_layers']: self.RT_module = GraphPropertyPredictor( hidden_features=self.GNN_module.get_embedding_dimension(), - 
static_features=model_params["static_rt_feature_dimension"], + static_features=model_params['static_rt_feature_dimension'], out_dimension=1, - dense_depth=model_params["dense_layers"], - dense_dim=model_params["dense_dim"], - residual_connections=model_params["residual_connections"], - pooling_func=model_params["pooling_func"], - input_dropout=model_params["input_dropout"], - latent_dropout=model_params["latent_dropout"], + dense_depth=model_params['dense_layers'], + dense_dim=model_params['dense_dim'], + residual_connections=model_params['residual_connections'], + pooling_func=model_params['pooling_func'], + input_dropout=model_params['input_dropout'], + latent_dropout=model_params['latent_dropout'], ) self.CCS_module = GraphPropertyPredictor( hidden_features=self.GNN_module.get_embedding_dimension(), - static_features=model_params["static_rt_feature_dimension"], + static_features=model_params['static_rt_feature_dimension'], out_dimension=1, - dense_depth=model_params["dense_layers"], - dense_dim=model_params["dense_dim"], - residual_connections=model_params["residual_connections"], - pooling_func=model_params["pooling_func"], - input_dropout=model_params["input_dropout"], - latent_dropout=model_params["latent_dropout"], + dense_depth=model_params['dense_layers'], + dense_dim=model_params['dense_dim'], + residual_connections=model_params['residual_connections'], + pooling_func=model_params['pooling_func'], + input_dropout=model_params['input_dropout'], + latent_dropout=model_params['latent_dropout'], ) - self.set_transform("double_softmax") + self.set_transform('double_softmax') self.model_params = model_params def _version_control_model_params(self, model_params: Dict) -> None: @@ -105,26 +107,26 @@ def _version_control_model_params(self, model_params: Dict) -> None: model_params (Dict): Dictionary containing model parameters. 
""" - if "residual_connections" not in model_params: - model_params["residual_connections"] = False - if "layer_stacking" not in model_params: - model_params["layer_stacking"] = False - if "prepare_additional_layers" not in model_params: - model_params["prepare_additional_layers"] = ( + if 'residual_connections' not in model_params: + model_params['residual_connections'] = False + if 'layer_stacking' not in model_params: + model_params['layer_stacking'] = False + if 'prepare_additional_layers' not in model_params: + model_params['prepare_additional_layers'] = ( True # Defaults to True, since old models have RT/CCS modules ) - if "dense_dim" not in model_params: - model_params["dense_dim"] = ( + if 'dense_dim' not in model_params: + model_params['dense_dim'] = ( None # None defaults to the number of input features (GNN output dimension) ) if ( - "subgraph_features" not in model_params + 'subgraph_features' not in model_params ): # No subgraph features in older models - model_params["subgraph_features"] = False - if "pooling" not in model_params: - model_params["pooling_func"] = "avg" - if "layer_norm" not in model_params: - model_params["layer_norm"] = False + model_params['subgraph_features'] = False + if 'pooling' not in model_params: + model_params['pooling_func'] = 'avg' + if 'layer_norm' not in model_params: + model_params['layer_norm'] = False return @@ -152,17 +154,17 @@ def set_dropout_rate(self, input_dropout: float, latent_dropout: float) -> None: self.CCS_module.latent_dropout.p = latent_dropout def set_transform( - self, transformation: Literal["softmax", "double_softmax", "off"] + self, transformation: Literal['softmax', 'double_softmax', 'off'] ): self.softmax = torch.nn.Softmax(dim=0) - if transformation == "double_softmax": + if transformation == 'double_softmax': self.transform = lambda y: 2.0 * self.softmax(y) # TODO make torch module - elif transformation == "softmax": + elif transformation == 'softmax': self.transform = self.softmax - elif transformation == "off": + elif transformation == 'off': self.transform = torch.nn.Identity() else: - raise ValueError(f"Unknown transformation type: {transformation}") + raise ValueError(f'Unknown transformation type: {transformation}') """ Compile output is the heart of the fragment probability prediction. 
@@ -180,8 +182,8 @@ def _compile_output(self, edge_values, graph_values, batch) -> torch.tensor: segment_ptr = [0] # cumulative boundaries per-graph (len=num_graphs+1) # Map edges to graph index (repeat left nodes according to edge dimension and retrieve graph/batch index) - edge_graph_map = batch["batch"][ - torch.repeat_interleave(batch["edge_index"][0, :], self.edge_dim) + edge_graph_map = batch['batch'][ + torch.repeat_interleave(batch['edge_index'][0, :], self.edge_dim) ] for i in range(batch.num_graphs): edges = edge_values.flatten()[ @@ -201,17 +203,17 @@ def _compile_output(self, edge_values, graph_values, batch) -> torch.tensor: ) def get_graph_embedding(self, batch): - batch["node_embedding"] = self.node_embedding(batch["x"]) - batch["edge_embedding"] = self.edge_embedding(batch["edge_attr"]) + batch['node_embedding'] = self.node_embedding(batch['x']) + batch['edge_embedding'] = self.edge_embedding(batch['edge_attr']) X = self.GNN_module(batch) pooling_func = self.precursor_module.pooling_func - return pooling_func(X, batch["batch"]) + return pooling_func(X, batch['batch']) def forward(self, batch, with_RT=False, with_CCS=False): # Embed node features - batch["node_embedding"] = self.node_embedding(batch["x"]) - batch["edge_embedding"] = self.edge_embedding(batch["edge_attr"]) + batch['node_embedding'] = self.node_embedding(batch['x']) + batch['edge_embedding'] = self.edge_embedding(batch['edge_attr']) X = self.GNN_module(batch) @@ -221,38 +223,38 @@ def forward(self, batch, with_RT=False, with_CCS=False): edge_values, graph_values, batch ) - output = {"fragment_probs": fragment_probs, "segment_ptr": segment_ptr} + output = {'fragment_probs': fragment_probs, 'segment_ptr': segment_ptr} if with_RT: - rt_values = self.RT_module(X, batch, covariate_tag="static_rt_features") - output["rt"] = rt_values + rt_values = self.RT_module(X, batch, covariate_tag='static_rt_features') + output['rt'] = rt_values if with_CCS: - ccs_values = self.CCS_module(X, batch, covariate_tag="static_rt_features") - output["ccs"] = ccs_values + ccs_values = self.CCS_module(X, batch, covariate_tag='static_rt_features') + output['ccs'] = ccs_values return output @classmethod - def load(cls, PATH: str) -> "FioraModel": + def load(cls, PATH: str) -> 'FioraModel': - with open(PATH, "rb") as f: + with open(PATH, 'rb') as f: model = dill.load(f) if not isinstance(model, cls): raise ValueError( - f"file {PATH} contains incorrect model class {type(model)}" + f'file {PATH} contains incorrect model class {type(model)}' ) return model @classmethod - def load_from_state_dict(cls, PATH: str) -> "FioraModel": + def load_from_state_dict(cls, PATH: str) -> 'FioraModel': - PARAMS_PATH = PATH.replace(".pt", "_params.json") - STATE_PATH = PATH.replace(".pt", "_state.pt") + PARAMS_PATH = PATH.replace('.pt', '_params.json') + STATE_PATH = PATH.replace('.pt', '_state.pt') - with open(PARAMS_PATH, "r") as fp: + with open(PARAMS_PATH, 'r') as fp: params = json.load(fp) model = FioraModel(params) model.load_state_dict( @@ -265,26 +267,26 @@ def load_from_state_dict(cls, PATH: str) -> "FioraModel": if not isinstance(model, cls): raise ValueError( - f"file {PATH} contains incorrect model class {type(model)}" + f'file {PATH} contains incorrect model class {type(model)}' ) return model - def save(self, PATH: str, dev: str = "cpu") -> None: + def save(self, PATH: str, dev: str = 'cpu') -> None: prev_device = next(self.parameters()).device # Set device to cpu for saving self.to(dev) - with open(PATH, "wb") as f: + with open(PATH, 'wb') as 
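# Illustrative sketch (not from the patch): save()/load_from_state_dict above
# persist the model twice, as a dill pickle and as a (params.json, state.pt)
# pair that can be reloaded after refactors. The state-dict half in miniature;
# TinyModel and the file names are hypothetical.
import json
import torch

class TinyModel(torch.nn.Module):
    def __init__(self, params):
        super().__init__()
        self.linear = torch.nn.Linear(params['in_dim'], params['out_dim'])

params = {'in_dim': 8, 'out_dim': 2}
torch.save(TinyModel(params).state_dict(), 'tiny_state.pt')
with open('tiny_params.json', 'w') as fp:
    json.dump(params, fp)

with open('tiny_params.json') as fp:  # rebuild architecture, then load weights
    restored = TinyModel(json.load(fp))
restored.load_state_dict(torch.load('tiny_state.pt'))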
f: dill.dump(self.to(dev), f) # Save state_dict and parameters as backup - PATH = ".".join(PATH.split(".")[:-1]) + "_params.json" - with open(PATH, "w") as fp: + PATH = '.'.join(PATH.split('.')[:-1]) + '_params.json' + with open(PATH, 'w') as fp: json.dump(self.model_params, fp) - PATH = PATH.replace("_params.json", "_state.pt") + PATH = PATH.replace('_params.json', '_state.pt') torch.save(self.to(dev).state_dict(), PATH) # Reset to previous device diff --git a/fiora/GNN/GNN.py b/fiora/GNN/GNN.py index a28689a..fa13081 100644 --- a/fiora/GNN/GNN.py +++ b/fiora/GNN/GNN.py @@ -1,41 +1,42 @@ +from typing import Literal + import torch import torch_geometric.nn as geom_nn -from typing import Literal """ Geometric Models """ GeometricLayer = { - "GraphConv": { - "Layer": geom_nn.GraphConv, - "divide_output_dim": False, - "const_args": {"aggr": "mean"}, - "batch_args": {"edge_index": "edge_index"}, + 'GraphConv': { + 'Layer': geom_nn.GraphConv, + 'divide_output_dim': False, + 'const_args': {'aggr': 'mean'}, + 'batch_args': {'edge_index': 'edge_index'}, }, - "GAT": { - "Layer": geom_nn.GATConv, - "divide_output_dim": True, - "const_args": {"heads": 5}, - "batch_args": {"edge_index": "edge_index", "edge_attr": "edge_embedding"}, + 'GAT': { + 'Layer': geom_nn.GATConv, + 'divide_output_dim': True, + 'const_args': {'heads': 5}, + 'batch_args': {'edge_index': 'edge_index', 'edge_attr': 'edge_embedding'}, }, - "RGCNConv": { - "Layer": geom_nn.RGCNConv, - "divide_output_dim": False, - "const_args": {"aggr": "mean", "num_relations": 4}, - "batch_args": {"edge_index": "edge_index", "edge_type": "edge_type"}, + 'RGCNConv': { + 'Layer': geom_nn.RGCNConv, + 'divide_output_dim': False, + 'const_args': {'aggr': 'mean', 'num_relations': 4}, + 'batch_args': {'edge_index': 'edge_index', 'edge_type': 'edge_type'}, }, - "TransformerConv": { - "Layer": geom_nn.TransformerConv, - "divide_output_dim": True, - "const_args": {"heads": 8, "edge_dim": 300}, - "batch_args": {"edge_index": "edge_index", "edge_attr": "edge_embedding"}, + 'TransformerConv': { + 'Layer': geom_nn.TransformerConv, + 'divide_output_dim': True, + 'const_args': {'heads': 8, 'edge_dim': 300}, + 'batch_args': {'edge_index': 'edge_index', 'edge_attr': 'edge_embedding'}, }, - "CGConv": { - "Layer": geom_nn.CGConv, - "divide_output_dim": False, - "const_args": {"aggr": "mean"}, # , 'dim': 300}, - "batch_args": {"edge_index": "edge_index", "edge_attr": "edge_embedding"}, + 'CGConv': { + 'Layer': geom_nn.CGConv, + 'divide_output_dim': False, + 'const_args': {'aggr': 'mean'}, # , 'dim': 300}, + 'batch_args': {'edge_index': 'edge_index', 'edge_attr': 'edge_embedding'}, }, } @@ -50,10 +51,10 @@ def __init__( hidden_features: int, depth: int, embedding_dim: int = None, - embedding_aggregation_type: str = "concat", + embedding_aggregation_type: str = 'concat', gnn_type: Literal[ - "GraphConv", "GAT", "RGCNConv", "TransformerConv", "CGConv" - ] = "RGCNConv", + 'GraphConv', 'GAT', 'RGCNConv', 'TransformerConv', 'CGConv' + ] = 'RGCNConv', layer_norm: bool = False, residual_connections: bool = False, layer_stacking: bool = False, @@ -88,15 +89,15 @@ def __init__( self.layer_norms = torch.nn.ModuleList() for _ in range(depth): layers += [ - GeometricLayer[gnn_type]["Layer"]( + GeometricLayer[gnn_type]['Layer']( node_features, int( hidden_features - / GeometricLayer[gnn_type]["const_args"]["heads"] + / GeometricLayer[gnn_type]['const_args']['heads'] ) - if GeometricLayer[gnn_type]["divide_output_dim"] + if GeometricLayer[gnn_type]['divide_output_dim'] else 
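# Illustrative sketch (not from the patch): 'divide_output_dim' in the registry
# above compensates for multi-head layers concatenating their heads. A GAT
# layer must emit hidden // heads channels per head to land back on `hidden`
# after concatenation (requires torch_geometric):
import torch
import torch_geometric.nn as geom_nn

hidden, heads = 300, 5
layer = geom_nn.GATConv(in_channels=128, out_channels=hidden // heads, heads=heads)
x = torch.randn(10, 128)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
print(layer(x, edge_index).shape)  # torch.Size([10, 300])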
hidden_features, - **GeometricLayer[gnn_type]["const_args"], + **GeometricLayer[gnn_type]['const_args'], ) ] if layer_norm: @@ -107,7 +108,7 @@ def __init__( def forward(self, batch): # Initialize node embeddings - X = batch["node_embedding"] + X = batch['node_embedding'] X = self.input_dropout(X) # If layer stacking is enabled, stack the node features @@ -116,7 +117,7 @@ def forward(self, batch): # Apply graph layers batch_args = { key: batch[value] - for key, value in GeometricLayer[self.gnn_type]["batch_args"].items() + for key, value in GeometricLayer[self.gnn_type]['batch_args'].items() } for i, layer in enumerate(self.graph_layers): X_skip = X @@ -139,7 +140,7 @@ def get_embedding_dimension(self): """Get the output dimension of the GNN.""" if len(self.graph_layers) == 0: if self.input_embedding_dim is None: - raise ValueError("embedding_dim must be provided when depth=0.") + raise ValueError('embedding_dim must be provided when depth=0.') return self.input_embedding_dim return self.graph_layers[-1].out_channels * ( len(self.graph_layers) + 1 if self.layer_stacking else 1 diff --git a/fiora/GNN/GraphPropertyPredictor.py b/fiora/GNN/GraphPropertyPredictor.py index 6e386ff..c06e49b 100644 --- a/fiora/GNN/GraphPropertyPredictor.py +++ b/fiora/GNN/GraphPropertyPredictor.py @@ -1,8 +1,8 @@ +from typing import Literal + import torch import torch_geometric.nn as geom_nn -from typing import Literal - class GraphPropertyPredictor(torch.nn.Module): def __init__( @@ -13,7 +13,7 @@ def __init__( dense_depth: int = 0, dense_dim: int = None, residual_connections: bool = False, - pooling_func: Literal["avg", "max"] = "avg", + pooling_func: Literal['avg', 'max'] = 'avg', input_dropout: float = 0, latent_dropout: float = 0, ) -> None: @@ -33,7 +33,7 @@ def __init__( self.activation = torch.nn.ELU() self.pooling_func = ( geom_nn.global_mean_pool - if pooling_func == "avg" + if pooling_func == 'avg' else geom_nn.global_max_pool ) self.input_dropout = torch.nn.Dropout(input_dropout) @@ -45,7 +45,7 @@ def __init__( hidden_dimension = dense_dim if dense_dim is not None else num_features if hidden_dimension != num_features and residual_connections: raise NotImplementedError( - "Residual connections require the hidden dimension to match the input dimension." + 'Residual connections require the hidden dimension to match the input dimension.' 
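# Illustrative sketch (not from the patch): GraphPropertyPredictor pools node
# states per graph and concatenates static covariates before its dense head
# (requires torch_geometric):
import torch
import torch_geometric.nn as geom_nn

X = torch.randn(7, 16)                        # node states, two graphs total
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])   # graph id of each node
static = torch.randn(2, 4)                    # per-graph covariates
pooled = geom_nn.global_mean_pool(X, batch)   # -> [2, 16]
print(torch.cat([pooled, static], dim=-1).shape)  # torch.Size([2, 20])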
) for _ in range(dense_depth): dense_layers += [torch.nn.Linear(num_features, hidden_dimension)] @@ -54,8 +54,8 @@ def __init__( self.output_layer = torch.nn.Linear(num_features, out_dimension) - def forward(self, X, batch, covariate_tag="static_graph_features"): - X = self.pooling_func(X, batch["batch"]) + def forward(self, X, batch, covariate_tag='static_graph_features'): + X = self.pooling_func(X, batch['batch']) X = torch.cat( [X, batch[covariate_tag]], axis=-1 ) # self.input_dropout(batch["static_graph_features"]) diff --git a/fiora/GNN/Losses.py b/fiora/GNN/Losses.py index 1e0d158..93774a5 100644 --- a/fiora/GNN/Losses.py +++ b/fiora/GNN/Losses.py @@ -18,8 +18,8 @@ def forward(self, input, target, weight): class WeightedMSEMetric(Metric): def __init__(self, **kwargs): super().__init__(**kwargs) - self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("numel", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('numel', default=torch.tensor(0), dist_reduce_fx='sum') def update(self, preds: Tensor, target: Tensor, weight: Tensor) -> None: self.sum += (weight * (preds - target) ** 2).sum() @@ -44,8 +44,8 @@ def forward(self, input, target, weight): class WeightedMAEMetric(Metric): def __init__(self, **kwargs): super().__init__(**kwargs) - self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("numel", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('numel', default=torch.tensor(0), dist_reduce_fx='sum') def update(self, preds: Tensor, target: Tensor, weight: Tensor) -> None: self.sum += (weight * torch.abs(preds - target)).sum() @@ -61,7 +61,7 @@ class GraphwiseKLLoss(torch.nn.Module): def __init__( self, eps: float = 1e-8, - reduction: str = "mean", + reduction: str = 'mean', normalize_targets: bool = True, normalize_pred: bool = True, ): @@ -79,7 +79,7 @@ def forward( weight: torch.Tensor = None, ): assert segment_ptr.dim() == 1 and segment_ptr.numel() >= 2, ( - "segment_ptr must be 1D with at least 2 entries" + 'segment_ptr must be 1D with at least 2 entries' ) num_graphs = segment_ptr.numel() - 1 @@ -105,9 +105,9 @@ def forward( total = total + kl total_el += r - l - if self.reduction == "sum": + if self.reduction == 'sum': return total - elif self.reduction == "mean_edge": + elif self.reduction == 'mean_edge': return total / max(total_el, 1) else: return total / max(num_graphs, 1) @@ -117,7 +117,7 @@ class GraphwiseKLLossMetric(Metric): def __init__( self, eps: float = 1e-8, - reduction: str = "mean", + reduction: str = 'mean', normalize_targets: bool = True, normalize_pred: bool = True, **kwargs, @@ -127,16 +127,16 @@ def __init__( self.reduction = reduction self.normalize_targets = normalize_targets self.normalize_pred = normalize_pred - self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state('total', default=torch.tensor(0.0), dist_reduce_fx='sum') self.add_state( - "total_graphs", + 'total_graphs', default=torch.tensor(0, dtype=torch.long), - dist_reduce_fx="sum", + dist_reduce_fx='sum', ) self.add_state( - "total_elements", + 'total_elements', default=torch.tensor(0, dtype=torch.long), - dist_reduce_fx="sum", + dist_reduce_fx='sum', ) def update( @@ -180,9 +180,9 @@ def update( self.total_elements += torch.tensor(total_el, device=self.total_elements.device) def compute(self) -> Tensor: - if self.reduction == 
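# Illustrative sketch (not from the patch): the Metric subclasses above follow
# the torchmetrics pattern of accumulating in add_state buffers (which can be
# all-reduced across workers) and dividing only in compute():
import torch
from torchmetrics import Metric

class WeightedMSE(Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state('acc', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('n', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, preds, target, weight):
        self.acc += (weight * (preds - target) ** 2).sum()
        self.n += target.numel()

    def compute(self):
        return self.acc / self.n

m = WeightedMSE()
m.update(torch.tensor([1.0, 2.0]), torch.tensor([0.0, 2.0]), torch.tensor([0.5, 1.0]))
print(m.compute())  # tensor(0.2500)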
"sum": + if self.reduction == 'sum': return self.total - elif self.reduction == "mean_edge": + elif self.reduction == 'mean_edge': return self.total / torch.clamp(self.total_elements.float(), min=1.0) else: return self.total / torch.clamp(self.total_graphs.float(), min=1.0) diff --git a/fiora/GNN/Trainer.py b/fiora/GNN/Trainer.py index 6666903..e6a0f4a 100644 --- a/fiora/GNN/Trainer.py +++ b/fiora/GNN/Trainer.py @@ -1,18 +1,19 @@ from abc import ABC, abstractmethod -import torch +from typing import Any, Dict, List, Literal + import numpy as np +import torch +from sklearn.model_selection import train_test_split from torch.utils.data import Dataset from torchmetrics import ( Accuracy, - MetricTracker, + MeanAbsoluteError, + MeanSquaredError, MetricCollection, + MetricTracker, Precision, Recall, - MeanSquaredError, - MeanAbsoluteError, ) -from sklearn.model_selection import train_test_split -from typing import Literal, List, Dict, Any class Trainer(ABC): @@ -26,7 +27,7 @@ def __init__( val_keys: List[int] | None = None, seed: int = 42, num_workers: int = 0, - device: str = "cpu", + device: str = 'cpu', ) -> None: self.only_training = only_training @@ -56,11 +57,11 @@ def _split_by_group( ): train_keys = train_keys or [] val_keys = val_keys or [] - group_ids = [getattr(x, "group_id") for x in data] + group_ids = [getattr(x, 'group_id') for x in data] keys = np.unique(group_ids) if len(train_keys) > 0 and len(val_keys) > 0: self.train_keys, self.val_keys = train_keys, val_keys - print("Using pre-set train/validation keys") + print('Using pre-set train/validation keys') else: self.train_keys, self.val_keys = train_test_split( keys, test_size=1 - train_val_split, random_state=seed @@ -72,49 +73,49 @@ def _split_by_group( def _get_default_metrics( self, - problem_type: Literal["classification", "regression", "softmax_regression"], + problem_type: Literal['classification', 'regression', 'softmax_regression'], ): metrics = { data_split: MetricTracker( MetricCollection( { - "acc": Accuracy("binary", num_classes=1), - "prec": Precision("binary", num_classes=1), - "rec": Recall("binary", num_classes=1), + 'acc': Accuracy('binary', num_classes=1), + 'prec': Precision('binary', num_classes=1), + 'rec': Recall('binary', num_classes=1), } ) - if problem_type == "classification" + if problem_type == 'classification' else MetricCollection( - {"mse": MeanSquaredError(), "mae": MeanAbsoluteError()} + {'mse': MeanSquaredError(), 'mae': MeanAbsoluteError()} ) ).to(self.device) - for data_split in ["train", "val", "masked_val", "test"] + for data_split in ['train', 'val', 'masked_val', 'test'] } return metrics def _init_checkpoint_system(self, save_path: str) -> None: self.checkpoint_stats = { - "epoch": -1, - "val_loss": float("inf"), - "sqrt_val_loss": float("inf"), - "file": save_path, + 'epoch': -1, + 'val_loss': float('inf'), + 'sqrt_val_loss': float('inf'), + 'file': save_path, } def _update_checkpoint( self, new_checkpoint_data: Dict[str, Any], model, save_checkpoint: bool = True ) -> None: self.checkpoint_stats.update(new_checkpoint_data) - model.save(self.checkpoint_stats["file"]) + model.save(self.checkpoint_stats['file']) def _init_history(self) -> None: self.history = { - "epoch": [], - "train_error": [], - "sqrt_train_error": [], - "val_error": [], - "sqrt_val_error": [], - "lr": [], + 'epoch': [], + 'train_error': [], + 'sqrt_train_error': [], + 'val_error': [], + 'sqrt_val_error': [], + 'lr': [], } @staticmethod @@ -124,24 +125,24 @@ def _to_float(value): return float(value) def 
_extract_primary_error(self, stats): - if "mse" in stats: - mse = self._to_float(stats["mse"]) + if 'mse' in stats: + mse = self._to_float(stats['mse']) return mse, float(np.sqrt(mse)) - if "mae" in stats: - mae = self._to_float(stats["mae"]) - return mae, float("nan") + if 'mae' in stats: + mae = self._to_float(stats['mae']) + return mae, float('nan') key = next(iter(stats.keys())) - return self._to_float(stats[key]), float("nan") + return self._to_float(stats[key]), float('nan') def _update_history(self, epoch, train_stats, val_stats, lr) -> None: train_error, train_sqrt_error = self._extract_primary_error(train_stats) val_error, val_sqrt_error = self._extract_primary_error(val_stats) - self.history["epoch"].append(epoch) - self.history["train_error"].append(train_error) - self.history["sqrt_train_error"].append(train_sqrt_error) - self.history["val_error"].append(val_error) - self.history["sqrt_val_error"].append(val_sqrt_error) - self.history["lr"].append(lr) + self.history['epoch'].append(epoch) + self.history['train_error'].append(train_error) + self.history['sqrt_train_error'].append(train_sqrt_error) + self.history['val_error'].append(val_error) + self.history['sqrt_val_error'].append(val_sqrt_error) + self.history['lr'].append(lr) def is_group_in_training_set(self, group_id): return group_id in self.train_keys diff --git a/fiora/GNN/fabric_training.py b/fiora/GNN/fabric_training.py index 8278fe8..2127eec 100644 --- a/fiora/GNN/fabric_training.py +++ b/fiora/GNN/fabric_training.py @@ -33,11 +33,11 @@ def build_loss_kwargs( weight_tensor = ( weight_tensor_override if weight_tensor_override is not None - else batch["weight_tensor"] + else batch['weight_tensor'] ) - kwargs["weight"] = weight_tensor if mask is None else weight_tensor[mask] - if include_segment_ptr and getattr(loss_fn, "requires_segment_ptr", False): - kwargs["segment_ptr"] = y_pred.get("segment_ptr") + kwargs['weight'] = weight_tensor if mask is None else weight_tensor[mask] + if include_segment_ptr and getattr(loss_fn, 'requires_segment_ptr', False): + kwargs['segment_ptr'] = y_pred.get('segment_ptr') return kwargs @@ -53,19 +53,19 @@ def add_rt_ccs_loss( if with_rt: kwargs_rt = {} if with_weights: - kwargs_rt["weight"] = batch["weight"][batch["retention_mask"]] + kwargs_rt['weight'] = batch['weight'][batch['retention_mask']] loss = loss + loss_fn( - y_pred["rt"][batch["retention_mask"]], - batch["retention_time"][batch["retention_mask"]], + y_pred['rt'][batch['retention_mask']], + batch['retention_time'][batch['retention_mask']], **kwargs_rt, ) if with_ccs: kwargs_ccs = {} if with_weights: - kwargs_ccs["weight"] = batch["weight"][batch["ccs_mask"]] + kwargs_ccs['weight'] = batch['weight'][batch['ccs_mask']] loss = loss + loss_fn( - y_pred["ccs"][batch["ccs_mask"]], - batch["ccs"][batch["ccs_mask"]], + y_pred['ccs'][batch['ccs_mask']], + batch['ccs'][batch['ccs_mask']], **kwargs_ccs, ) return loss @@ -73,7 +73,7 @@ def add_rt_ccs_loss( def safe_metric_update(metric, preds, target, kwargs: dict | None = None): kwargs = kwargs or {} - update = getattr(metric, "update", None) + update = getattr(metric, 'update', None) if callable(update): try: update(preds, target, **kwargs) @@ -90,7 +90,7 @@ def safe_metric_update(metric, preds, target, kwargs: dict | None = None): def metric_label_and_value(metric_or_stats, preferred_key: str | None = None): stats = ( metric_or_stats.compute() - if hasattr(metric_or_stats, "compute") + if hasattr(metric_or_stats, 'compute') else metric_or_stats ) @@ -98,7 +98,7 @@ def 
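# Illustrative sketch (not from the patch): add_rt_ccs_loss above computes the
# auxiliary RT/CCS terms only over samples whose boolean mask says a label is
# present:
import torch
import torch.nn.functional as F

pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.5, 0.0, 2.5])    # 0.0 is just a placeholder
mask = torch.tensor([True, False, True])  # which samples carry a label
print(F.mse_loss(pred[mask], target[mask]))  # tensor(0.2500)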
metric_label_and_value(metric_or_stats, preferred_key: str | None = None): if preferred_key is not None and preferred_key in stats: key = preferred_key else: - for candidate in ("kl", "mse", "mae", "acc"): + for candidate in ('kl', 'mse', 'mae', 'acc'): if candidate in stats: key = candidate break @@ -106,11 +106,11 @@ def metric_label_and_value(metric_or_stats, preferred_key: str | None = None): key = next(iter(stats.keys())) value = stats[key] else: - key = preferred_key or "metric" + key = preferred_key or 'metric' value = stats - label = "rmse" if key == "mse" else key - if key == "mse": + label = 'rmse' if key == 'mse' else key + if key == 'mse': value = torch.sqrt(value) if isinstance(value, torch.Tensor): value = float(value.detach().cpu().item()) @@ -120,13 +120,13 @@ def metric_label_and_value(metric_or_stats, preferred_key: str | None = None): def resolve_fabric_runtime(device: str): - if device.startswith("cuda"): - if ":" in device: - return "cuda", [int(device.split(":")[-1])] - return "cuda", 1 - if device.startswith("mps"): - return "mps", 1 - return "cpu", 1 + if device.startswith('cuda'): + if ':' in device: + return 'cuda', [int(device.split(':')[-1])] + return 'cuda', 1 + if device.startswith('mps'): + return 'mps', 1 + return 'cpu', 1 def seed_everything(seed: int) -> None: @@ -137,7 +137,7 @@ def seed_everything(seed: int) -> None: torch.cuda.manual_seed_all(seed) -def build_progress_iterator(dataloader, enabled=False, desc=""): +def build_progress_iterator(dataloader, enabled=False, desc=''): if not enabled: return dataloader try: @@ -149,7 +149,7 @@ def build_progress_iterator(dataloader, enabled=False, desc=""): def unwrap_model(model): - return model.module if hasattr(model, "module") else model + return model.module if hasattr(model, 'module') else model def apply_precursor_loss_weight( @@ -196,9 +196,9 @@ def run_epoch( rt_metric: bool, optimizer=None, use_validation_mask: bool = False, - mask_name: str = "validation_mask", + mask_name: str = 'validation_mask', show_progress: bool = False, - progress_desc: str = "", + progress_desc: str = '', non_blocking_transfer: bool = False, precursor_loss_weight: float = 1.0, ): @@ -222,13 +222,13 @@ def run_epoch( with torch.set_grad_enabled(is_training): y_pred = model(batch, with_RT=with_rt, with_CCS=with_ccs) use_weight_vector = with_weights or getattr( - loss_fn, "requires_segment_ptr", False + loss_fn, 'requires_segment_ptr', False ) weight_tensor = None if use_weight_vector: weight_tensor = apply_precursor_loss_weight( - batch["weight_tensor"], - y_pred.get("segment_ptr"), + batch['weight_tensor'], + y_pred.get('segment_ptr'), precursor_loss_weight, ) @@ -245,14 +245,14 @@ def run_epoch( weight_tensor_override=weight_tensor, ) loss = loss_fn( - y_pred["fragment_probs"][mask], + y_pred['fragment_probs'][mask], batch[y_tag][mask], **kwargs, ) if not rt_metric: safe_metric_update( metric, - y_pred["fragment_probs"][mask], + y_pred['fragment_probs'][mask], batch[y_tag][mask], kwargs, ) @@ -260,15 +260,15 @@ def run_epoch( if with_rt: safe_metric_update( metric, - y_pred["rt"][batch["retention_mask"]], - batch["retention_time"][batch["retention_mask"]], + y_pred['rt'][batch['retention_mask']], + batch['retention_time'][batch['retention_mask']], {}, ) if with_ccs: safe_metric_update( metric, - y_pred["ccs"][batch["ccs_mask"]], - batch["ccs"][batch["ccs_mask"]], + y_pred['ccs'][batch['ccs_mask']], + batch['ccs'][batch['ccs_mask']], {}, ) loss = add_rt_ccs_loss( @@ -292,24 +292,24 @@ def run_epoch( include_segment_ptr=True, 
weight_tensor_override=weight_tensor, ) - loss = loss_fn(y_pred["fragment_probs"], batch[y_tag], **kwargs) + loss = loss_fn(y_pred['fragment_probs'], batch[y_tag], **kwargs) if not rt_metric: safe_metric_update( - metric, y_pred["fragment_probs"], batch[y_tag], kwargs + metric, y_pred['fragment_probs'], batch[y_tag], kwargs ) else: if with_rt: safe_metric_update( metric, - y_pred["rt"][batch["retention_mask"]], - batch["retention_time"][batch["retention_mask"]], + y_pred['rt'][batch['retention_mask']], + batch['retention_time'][batch['retention_mask']], {}, ) if with_ccs: safe_metric_update( metric, - y_pred["ccs"][batch["ccs_mask"]], - batch["ccs"][batch["ccs_mask"]], + y_pred['ccs'][batch['ccs_mask']], + batch['ccs'][batch['ccs_mask']], {}, ) @@ -330,7 +330,7 @@ def run_epoch( fabric.backward(loss) optimizer.step() - avg_loss = loss_total / max(loss_batches, 1) if loss_batches > 0 else float("nan") + avg_loss = loss_total / max(loss_batches, 1) if loss_batches > 0 else float('nan') metric_label, metric_value = metric_label_and_value( metric, preferred_key=metric_name ) @@ -353,12 +353,12 @@ class TrainingState: def _init_history() -> dict: return { - "epoch": [], - "train_error": [], - "sqrt_train_error": [], - "val_error": [], - "sqrt_val_error": [], - "lr": [], + 'epoch': [], + 'train_error': [], + 'sqrt_train_error': [], + 'val_error': [], + 'sqrt_val_error': [], + 'lr': [], } @@ -369,20 +369,20 @@ def _record_history( train_result: EpochResult | None, val_result: EpochResult | None, ) -> None: - history["epoch"].append(epoch) - history["train_error"].append( - train_result.metric_value if train_result is not None else float("nan") + history['epoch'].append(epoch) + history['train_error'].append( + train_result.metric_value if train_result is not None else float('nan') ) - history["sqrt_train_error"].append( - train_result.metric_value if train_result is not None else float("nan") + history['sqrt_train_error'].append( + train_result.metric_value if train_result is not None else float('nan') ) - history["val_error"].append( - val_result.metric_value if val_result is not None else float("nan") + history['val_error'].append( + val_result.metric_value if val_result is not None else float('nan') ) - history["sqrt_val_error"].append( - val_result.metric_value if val_result is not None else float("nan") + history['sqrt_val_error'].append( + val_result.metric_value if val_result is not None else float('nan') ) - history["lr"].append(lr) + history['lr'].append(lr) def _run_train_epoch( @@ -491,15 +491,15 @@ def _step_scheduler( ) -> None: if scheduler is None: return - prev_lr = optimizer.param_groups[0]["lr"] + prev_lr = optimizer.param_groups[0]['lr'] if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): if monitor_metric is not None and not np.isnan(monitor_metric): scheduler.step(monitor_metric) else: scheduler.step() - curr_lr = optimizer.param_groups[0]["lr"] + curr_lr = optimizer.param_groups[0]['lr'] if logger is not None and fabric.is_global_zero and curr_lr < prev_lr: - logger(f"\t >> Learning rate reduced from {prev_lr:1.0e} to {curr_lr:1.0e}") + logger(f'\t >> Learning rate reduced from {prev_lr:1.0e} to {curr_lr:1.0e}') def _maybe_update_best( @@ -523,9 +523,9 @@ def _maybe_update_best( unwrap_model(model).save(output_path) if logger is not None: if baseline: - logger("\t >> Set baseline checkpoint to epoch 0") + logger('\t >> Set baseline checkpoint to epoch 0') else: - logger(f"\t >> Set new checkpoint to epoch {epoch}") + logger(f'\t >> Set new checkpoint to 
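# Illustrative sketch (not from the patch): _step_scheduler above feeds the
# monitored validation metric to ReduceLROnPlateau (other schedulers step
# unconditionally) and logs whenever the learning rate drops:
import torch

opt = torch.optim.Adam(torch.nn.Linear(4, 1).parameters(), lr=1e-3)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode='min', patience=2, factor=0.5
)
for val_loss in [1.0, 1.0, 1.0, 1.0]:  # stagnating metric
    sched.step(val_loss)
print(opt.param_groups[0]['lr'])  # 0.0005 once the plateau exceeds patience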
epoch {epoch}') def _log_epoch( @@ -543,41 +543,41 @@ def _log_epoch( train_label = ( train_result.metric_label if train_result is not None - else (val_result.metric_label if val_result is not None else "metric") + else (val_result.metric_label if val_result is not None else 'metric') ) val_label = ( val_result.metric_label if val_result is not None - else (train_result.metric_label if train_result is not None else "metric") + else (train_result.metric_label if train_result is not None else 'metric') ) train_loss_str = ( - f"{train_result.loss:.4f}" + f'{train_result.loss:.4f}' if train_result is not None and not np.isnan(train_result.loss) - else "n/a" + else 'n/a' ) val_loss_str = ( - f"{val_result.loss:.4f}" + f'{val_result.loss:.4f}' if val_result is not None and not np.isnan(val_result.loss) - else "n/a" + else 'n/a' ) train_metric_str = ( - f"{train_result.metric_value:.4f}" + f'{train_result.metric_value:.4f}' if train_result is not None and not np.isnan(train_result.metric_value) - else "n/a" + else 'n/a' ) val_metric_str = ( - f"{val_result.metric_value:.4f}" + f'{val_result.metric_value:.4f}' if val_result is not None and not np.isnan(val_result.metric_value) - else "n/a" + else 'n/a' ) logger( - f"Epoch {epoch}/{epochs} - " - f"loss: {train_loss_str} - " - f"val_loss: {val_loss_str} - " - f"train_{train_label}: {train_metric_str} - " - f"val_{val_label}: {val_metric_str}" + f'Epoch {epoch}/{epochs} - ' + f'loss: {train_loss_str} - ' + f'val_loss: {val_loss_str} - ' + f'train_{train_label}: {train_metric_str} - ' + f'val_{val_label}: {val_metric_str}' ) @@ -616,15 +616,15 @@ def train_fabric_loop( has_validation = len(val_data) > 0 accelerator, devices = resolve_fabric_runtime(device) warnings.filterwarnings( - "ignore", - message=r"The `srun` command is available on your system but is not used\..*", + 'ignore', + message=r'The `srun` command is available on your system but is not used\..*', category=PossibleUserWarning, ) - if accelerator == "cuda": - torch.set_float32_matmul_precision("high") + if accelerator == 'cuda': + torch.set_float32_matmul_precision('high') if pin_memory is None: - pin_memory = accelerator == "cuda" - use_non_blocking_transfer = bool(pin_memory and accelerator == "cuda") + pin_memory = accelerator == 'cuda' + use_non_blocking_transfer = bool(pin_memory and accelerator == 'cuda') fabric = Fabric(accelerator=accelerator, devices=devices) if launch_fabric: @@ -636,7 +636,7 @@ def train_fabric_loop( train_metric = metric_cls().to(fabric.device) val_metric = metric_cls().to(fabric.device) else: - metric_name = "mse" + metric_name = 'mse' train_metric = MeanSquaredError().to(fabric.device) val_metric = MeanSquaredError().to(fabric.device) @@ -644,12 +644,12 @@ def train_fabric_loop( optimizer = torch.optim.Adam( model.parameters(), lr=learning_rate, weight_decay=weight_decay ) - if scheduler is None and scheduler_name == "plateau": + if scheduler is None and scheduler_name == 'plateau': scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, patience=scheduler_patience, factor=scheduler_factor, - mode="min", + mode='min', ) train_loader = geom_loader.DataLoader( @@ -679,7 +679,7 @@ def train_fabric_loop( show_val_progress = has_validation and (len(val_data) > progress_threshold) state = TrainingState( - best_metric=float("inf"), best_epoch=-1, history=_init_history() + best_metric=float('inf'), best_epoch=-1, history=_init_history() ) if has_validation: @@ -698,7 +698,7 @@ def train_fabric_loop( use_validation_mask=use_validation_mask, 
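# Illustrative sketch (not from the patch): train_fabric_loop above resolves
# the accelerator, enables TF32 matmuls on CUDA, pins host memory only where it
# helps, and then launches a Lightning Fabric (requires the lightning package):
import torch
from lightning.fabric import Fabric

accelerator, devices = 'cpu', 1  # e.g. the result of resolve_fabric_runtime
if accelerator == 'cuda':
    torch.set_float32_matmul_precision('high')
pin_memory = accelerator == 'cuda'

fabric = Fabric(accelerator=accelerator, devices=devices)
fabric.launch()
print(fabric.device)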
validation_mask_name=validation_mask_name, show_progress=show_val_progress, - progress_desc=f"Val 0/{epochs}", + progress_desc=f'Val 0/{epochs}', non_blocking_transfer=use_non_blocking_transfer, precursor_loss_weight=precursor_loss_weight, ) @@ -706,7 +706,7 @@ def train_fabric_loop( _record_history( state.history, epoch=0, - lr=optimizer.param_groups[0]["lr"], + lr=optimizer.param_groups[0]['lr'], train_result=None, val_result=baseline_result, ) @@ -744,7 +744,7 @@ def train_fabric_loop( rt_metric=rt_metric, optimizer=optimizer, show_progress=show_train_progress, - progress_desc=f"Train {epoch}/{epochs}", + progress_desc=f'Train {epoch}/{epochs}', non_blocking_transfer=use_non_blocking_transfer, precursor_loss_weight=precursor_loss_weight, ) @@ -767,7 +767,7 @@ def train_fabric_loop( use_validation_mask=use_validation_mask, validation_mask_name=validation_mask_name, show_progress=show_val_progress, - progress_desc=f"Val {epoch}/{epochs}", + progress_desc=f'Val {epoch}/{epochs}', non_blocking_transfer=use_non_blocking_transfer, precursor_loss_weight=precursor_loss_weight, ) @@ -800,7 +800,7 @@ def train_fabric_loop( _record_history( state.history, epoch=epoch, - lr=optimizer.param_groups[0]["lr"], + lr=optimizer.param_groups[0]['lr'], train_result=train_result, val_result=val_result if is_val_cycle else None, ) @@ -815,14 +815,14 @@ def train_fabric_loop( if state.best_epoch < 0: state.best_epoch = epochs - state.best_metric = float("nan") + state.best_metric = float('nan') if output_path is not None and fabric.is_global_zero: unwrap_model(model).save(output_path) checkpoints = { - "epoch": state.best_epoch, - "val_loss": state.best_metric, - "sqrt_val_loss": state.best_metric, - "file": output_path, + 'epoch': state.best_epoch, + 'val_loss': state.best_metric, + 'sqrt_val_loss': state.best_metric, + 'file': output_path, } return checkpoints, state.history diff --git a/fiora/IO/LibraryLoader.py b/fiora/IO/LibraryLoader.py index 2c74f04..19ca896 100644 --- a/fiora/IO/LibraryLoader.py +++ b/fiora/IO/LibraryLoader.py @@ -9,7 +9,7 @@ def load_from_csv(self, path): return pd.read_csv(path, index_col=[0], low_memory=False) def load_from_msp(self): - raise NotImplementedError("MSP loading is not implemented yet.") + raise NotImplementedError('MSP loading is not implemented yet.') def clean_library(self): - raise NotImplementedError("Library cleaning is not implemented yet.") + raise NotImplementedError('Library cleaning is not implemented yet.') diff --git a/fiora/IO/cfmReader.py b/fiora/IO/cfmReader.py index e3ddfca..175e9e7 100644 --- a/fiora/IO/cfmReader.py +++ b/fiora/IO/cfmReader.py @@ -1,58 +1,58 @@ import pandas as pd -def read(source, sep: str = " ", as_df=False): - file = open(source, "r") +def read(source, sep: str = ' ', as_df=False): + file = open(source, 'r') data = [] data_piece = {} - precursor = "" + precursor = '' mz, intensity, annotation = [], [], [] energy2 = False for line in file: - if line == "\n": + if line == '\n': if energy2: - data_piece["peaks40"] = { - "mz": mz, - "intensity": intensity, - "annotation": annotation, + data_piece['peaks40'] = { + 'mz': mz, + 'intensity': intensity, + 'annotation': annotation, } mz, intensity, annotation = [], [], [] energy2 = False continue else: continue - if line.startswith("#PREDICTED"): + if line.startswith('#PREDICTED'): continue - if line.startswith("#In-silico"): - precursor = line.split("ESI-MS/MS ")[1].split(" Spectra")[0] + if line.startswith('#In-silico'): + precursor = line.split('ESI-MS/MS ')[1].split(' Spectra')[0] 
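+        # e.g. a header line '#In-silico ESI-MS/MS [M+H]+ Spectra'
+        # yields precursor = '[M+H]+' (header layout assumed from the
+        # split pattern above)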
continue - if line.strip().startswith("#ID="): + if line.strip().startswith('#ID='): energy2 = False data.append(data_piece) data_piece, mz, intensity, annotation = {}, [], [], [] - data_piece["Precursor_type"] = precursor - if "=" in line: - key = line.split("=")[0] - value = "=".join(line.strip().split("=", 1)[1:]) + data_piece['Precursor_type'] = precursor + if '=' in line: + key = line.split('=')[0] + value = '='.join(line.strip().split('=', 1)[1:]) data_piece[key] = value - elif line.strip() == "energy0": + elif line.strip() == 'energy0': continue - elif line.strip() == "energy1": - data_piece["peaks10"] = { - "mz": mz, - "intensity": intensity, - "annotation": annotation, + elif line.strip() == 'energy1': + data_piece['peaks10'] = { + 'mz': mz, + 'intensity': intensity, + 'annotation': annotation, } mz, intensity, annotation = [], [], [] # new data piece - elif line.strip() == "energy2": + elif line.strip() == 'energy2': energy2 = True - data_piece["peaks20"] = { - "mz": mz, - "intensity": intensity, - "annotation": annotation, + data_piece['peaks20'] = { + 'mz': mz, + 'intensity': intensity, + 'annotation': annotation, } mz, intensity, annotation = [], [], [] else: diff --git a/fiora/IO/fraggraphReader.py b/fiora/IO/fraggraphReader.py index ebfa59e..4a53863 100644 --- a/fiora/IO/fraggraphReader.py +++ b/fiora/IO/fraggraphReader.py @@ -5,16 +5,16 @@ def parser_fraggraph_gen(output_file): with open(output_file) as t: output = t.readlines() - output = [s.replace("\n", "") for s in output] + output = [s.replace('\n', '') for s in output] nfrags = int(output[0]) - frag_index = [int(output[i].split(" ")[0]) for i in range(1, nfrags + 1)] - frag_mass = [float(output[i].split(" ")[1]) for i in range(1, nfrags + 1)] - frag_smiles = [output[i].split(" ")[2] for i in range(1, nfrags + 1)] - loss_from = [int(output[i].split(" ")[0]) for i in range(nfrags + 2, len(output))] - loss_to = [int(output[i].split(" ")[1]) for i in range(nfrags + 2, len(output))] - loss_smiles = [output[i].split(" ")[2] for i in range(nfrags + 2, len(output))] + frag_index = [int(output[i].split(' ')[0]) for i in range(1, nfrags + 1)] + frag_mass = [float(output[i].split(' ')[1]) for i in range(1, nfrags + 1)] + frag_smiles = [output[i].split(' ')[2] for i in range(1, nfrags + 1)] + loss_from = [int(output[i].split(' ')[0]) for i in range(nfrags + 2, len(output))] + loss_to = [int(output[i].split(' ')[1]) for i in range(nfrags + 2, len(output))] + loss_smiles = [output[i].split(' ')[2] for i in range(nfrags + 2, len(output))] fragments = pd.DataFrame( - {"index": frag_index, "mass": frag_mass, "smiles": frag_smiles} + {'index': frag_index, 'mass': frag_mass, 'smiles': frag_smiles} ) - losses = pd.DataFrame({"from": loss_from, "to": loss_to, "smiles": loss_smiles}) - return {"fragments": fragments, "losses": losses} + losses = pd.DataFrame({'from': loss_from, 'to': loss_to, 'smiles': loss_smiles}) + return {'fragments': fragments, 'losses': losses} diff --git a/fiora/IO/mgfReader.py b/fiora/IO/mgfReader.py index 95048e2..242943d 100644 --- a/fiora/IO/mgfReader.py +++ b/fiora/IO/mgfReader.py @@ -2,8 +2,8 @@ import pandas as pd -def read(source, sep: str = " ", as_df=False, debug=False): - file = open(source, "r") +def read(source, sep: str = ' ', as_df=False, debug=False): + file = open(source, 'r') in_begin_ions = False data = [] data_piece = {} @@ -12,29 +12,29 @@ def read(source, sep: str = " ", as_df=False, debug=False): for line in file: if debug: print(line.strip()) - if line == "MASS=Monoisotopic\n": + if line == 
'MASS=Monoisotopic\n': continue # TODO edge case hacky solution - if line == "\n": + if line == '\n': continue - if line.startswith("#"): + if line.startswith('#'): continue - if line.startswith("NA#"): + if line.startswith('NA#'): continue - if line.strip() == "END IONS": + if line.strip() == 'END IONS': in_begin_ions = False - data_piece["peaks"] = {"mz": mz, "intensity": intensity, "annotation": ion} + data_piece['peaks'] = {'mz': mz, 'intensity': intensity, 'annotation': ion} data.append(data_piece) continue - if line.strip() == "BEGIN IONS" or line.strip() == "BEGIN IONS:": + if line.strip() == 'BEGIN IONS' or line.strip() == 'BEGIN IONS:': in_begin_ions = True data_piece, mz, intensity, ion = {}, [], [], [] continue - if "=" in line: - key = line.split("=")[0] - value = "=".join( - line.strip().split("=", 1)[1:] + if '=' in line: + key = line.split('=')[0] + value = '='.join( + line.strip().split('=', 1)[1:] ) # line.split('=', 1)[1].strip() data_piece[key] = value else: @@ -53,16 +53,16 @@ def read(source, sep: str = " ", as_df=False, debug=False): def get_spectrum_by_name(source, name): - file = open(source, "r") + file = open(source, 'r') - line_match = "TITLE=" + name + "\n" + line_match = 'TITLE=' + name + '\n' data_piece = {} mz, intensity, ion = [], [], [] found = False for line in file: - if line == "END IONS\n" and found: - data_piece["peaks"] = {"mz": mz, "intensity": intensity, "ion": ion} + if line == 'END IONS\n' and found: + data_piece['peaks'] = {'mz': mz, 'intensity': intensity, 'ion': ion} break if line == line_match: # exact name match found = True @@ -70,12 +70,12 @@ def get_spectrum_by_name(source, name): if not found: continue # skip ahead - if "=" in line: - key = line.split("=")[0] - value = line.split("=", 1)[1].strip() + if '=' in line: + key = line.split('=')[0] + value = line.split('=', 1)[1].strip() data_piece[key] = value else: - line_split = line.split(" ") + line_split = line.split(' ') mz.append(line_split[0].strip()) intensity.append(line_split[1].strip()) file.close() diff --git a/fiora/IO/mgfWriter.py b/fiora/IO/mgfWriter.py index 0d149c4..38cf287 100644 --- a/fiora/IO/mgfWriter.py +++ b/fiora/IO/mgfWriter.py @@ -1,25 +1,25 @@ def write_mgf( df, path, - peak_tag="peaks", + peak_tag='peaks', write_header=True, - headers=["TITLE", "RTINSECONDS", "PEPMASS", "CHARGE"], + headers=['TITLE', 'RTINSECONDS', 'PEPMASS', 'CHARGE'], header_map={}, annotation=False, ): for h in headers: if h not in header_map.keys(): header_map[h] = h - with open(path, "w") as outfile: + with open(path, 'w') as outfile: for x in df.index: - outfile.write("BEGIN IONS\n") + outfile.write('BEGIN IONS\n') peaks = df.loc[x][peak_tag] if write_header: for key in headers: - outfile.write(key + "=" + str(df.loc[x][header_map[key]]) + "\n") - for i in range(len(peaks["mz"])): - line = str(peaks["mz"][i]) + " " + str(peaks["intensity"][i]) + outfile.write(key + '=' + str(df.loc[x][header_map[key]]) + '\n') + for i in range(len(peaks['mz'])): + line = str(peaks['mz'][i]) + ' ' + str(peaks['intensity'][i]) if annotation: - line += " " + peaks["annotation"][i] - outfile.write(line + "\n") - outfile.write("END IONS\n") + line += ' ' + peaks['annotation'][i] + outfile.write(line + '\n') + outfile.write('END IONS\n') diff --git a/fiora/IO/molReader.py b/fiora/IO/molReader.py index e0be3db..d6affec 100644 --- a/fiora/IO/molReader.py +++ b/fiora/IO/molReader.py @@ -1,12 +1,11 @@ from rdkit import Chem - """ Functions to read mol files """ def load_MOL(path): - MOL_string = open(path, "r").read() 
+ MOL_string = open(path, 'r').read() m = Chem.MolFromMolBlock(MOL_string) return m diff --git a/fiora/IO/mspReader.py b/fiora/IO/mspReader.py index e2bc47e..176b428 100644 --- a/fiora/IO/mspReader.py +++ b/fiora/IO/mspReader.py @@ -1,26 +1,26 @@ import regex as re -def read(source, sep=" "): - file = open(source, "r") +def read(source, sep=' '): + file = open(source, 'r') data = [] data_piece = {} mz, intensity, ion = [], [], [] for line in file: - if "Name:" == line[0:5] or "NAME:" == line[0:5]: - data_piece["peaks"] = {"mz": mz, "intensity": intensity, "annotation": ion} + if 'Name:' == line[0:5] or 'NAME:' == line[0:5]: + data_piece['peaks'] = {'mz': mz, 'intensity': intensity, 'annotation': ion} data.append(data_piece) data_piece, mz, intensity, ion = {}, [], [], [] - if ":" in line: - key = line.split(":")[0] - value = line.split(":", 1)[1].strip() + if ':' in line: + key = line.split(':')[0] + value = line.split(':', 1)[1].strip() data_piece[key] = value else: - if line == "\n": + if line == '\n': continue ls = line.strip() line_split = ls.split(sep) @@ -28,7 +28,7 @@ def read(source, sep=" "): intensity.append(float(line_split[1])) # ion.append(line_split[2].strip()) - data_piece["peaks"] = {"mz": mz, "intensity": intensity, "annotation": ion} + data_piece['peaks'] = {'mz': mz, 'intensity': intensity, 'annotation': ion} data.append(data_piece) file.close() @@ -36,21 +36,21 @@ def read(source, sep=" "): def read_minimal(source): - file = open(source, "r") + file = open(source, 'r') data = [] data_piece = {} mz, intensity = [], [] for line in file: - if "Name:" == line[0:5]: - data_piece["peaks"] = {"mz": mz, "intensity": intensity} + if 'Name:' == line[0:5]: + data_piece['peaks'] = {'mz': mz, 'intensity': intensity} data.append(data_piece) - data_piece = {"Name": line.split(":", 1)[1].strip()} + data_piece = {'Name': line.split(':', 1)[1].strip()} mz, intensity = [], [] continue - if ":" not in line: - line_split = line.split("\t") + if ':' not in line: + line_split = line.split('\t') mz.append(line_split[0]) intensity.append(line_split[1]) @@ -61,13 +61,13 @@ def read_minimal(source): def read_peptides(source): - file = open(source, "r") + file = open(source, 'r') pep_list = [] for line in file: - if "Name:" == line[0:5]: - li = line.strip("\n")[5:] - li = re.sub(r"[\d+ /]", "", li) + if 'Name:' == line[0:5]: + li = line.strip('\n')[5:] + li = re.sub(r'[\d+ /]', '', li) pep_list.append(li) file.close() @@ -75,20 +75,20 @@ def read_peptides(source): def read_sparse(source): - file = open(source, "r") + file = open(source, 'r') file.close() def readOld(source): - file = open(source, "r") + file = open(source, 'r') data = [] active_lines = [] for line in file: - if "Name:" == line[0:5]: + if 'Name:' == line[0:5]: data.append(make_data_piece(active_lines)) active_lines = [] - active_lines.append(line.strip("\n")) + active_lines.append(line.strip('\n')) data.append(make_data_piece(active_lines)) file.close() @@ -100,42 +100,42 @@ def make_data_piece(lines): mz, intensity, ion = [], [], [] for line in lines: - if ":" in line: - key = line.split(":")[0] - value = ":".join(line.split(":")[1:]) + if ':' in line: + key = line.split(':')[0] + value = ':'.join(line.split(':')[1:]) data_piece[key] = value else: - line_split = line.split("\t") + line_split = line.split('\t') mz.append(line_split[0]) intensity.append(line_split[1]) ion.append(line_split[2]) - data_piece["peaks"] = {"mz": mz, "intensity": intensity, "ion": ion} + data_piece['peaks'] = {'mz': mz, 'intensity': intensity, 'ion': ion} 
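+    # Illustrative record shape (keys besides 'Name' and 'peaks' depend on
+    # the library file; 'PrecursorMZ' here is hypothetical):
+    #   {'Name': '...', 'PrecursorMZ': '...',
+    #    'peaks': {'mz': [...], 'intensity': [...], 'ion': [...]}}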
return data_piece def get_spectrum_by_name(source, name): - file = open(source, "r") + file = open(source, 'r') - line_match = "Name: " + name + "\n" + line_match = 'Name: ' + name + '\n' data_piece = {} mz, intensity, ion = [], [], [] found = False for line in file: - if line[0:5] == "Name:" and found: - data_piece["peaks"] = {"mz": mz, "intensity": intensity, "ion": ion} + if line[0:5] == 'Name:' and found: + data_piece['peaks'] = {'mz': mz, 'intensity': intensity, 'ion': ion} break if line == line_match: # exact name match found = True if not found: continue - if ":" in line: - key = line.split(":")[0] - value = line.split(":", 1)[1].strip() + if ':' in line: + key = line.split(':')[0] + value = line.split(':', 1)[1].strip() data_piece[key] = value else: - line_split = line.split("\t") + line_split = line.split('\t') mz.append(line_split[0]) intensity.append(line_split[1]) ion.append(line_split[2].strip()) diff --git a/fiora/IO/mspWriter.py b/fiora/IO/mspWriter.py index 72efbed..9cf37b3 100644 --- a/fiora/IO/mspWriter.py +++ b/fiora/IO/mspWriter.py @@ -3,30 +3,30 @@ def write_msp( path, write_header=True, headers=[ - "Name", - "Precursor_type", - "Spectrum_type", - "PRECURSORMZ", - "RETENTIONTIME", - "Charge", - "Comments", - "Num peaks", + 'Name', + 'Precursor_type', + 'Spectrum_type', + 'PRECURSORMZ', + 'RETENTIONTIME', + 'Charge', + 'Comments', + 'Num peaks', ], annotation: bool = False, ): - with open(path, "w") as outfile: + with open(path, 'w') as outfile: for x in df.index: peaks = df.loc[x].peaks if write_header: for key in headers: - outfile.write(key + ": " + str(df.loc[x][key]) + "\n") - outfile.write(f"Num peaks: {len(peaks['mz'])}\n") - for i in range(len(peaks["mz"])): - peak_annotation = f"\t{peaks['annotation'][i]}" if annotation else "" + outfile.write(key + ': ' + str(df.loc[x][key]) + '\n') + outfile.write(f'Num peaks: {len(peaks["mz"])}\n') + for i in range(len(peaks['mz'])): + peak_annotation = f'\t{peaks["annotation"][i]}' if annotation else '' outfile.write( - str(peaks["mz"][i]) - + "\t" - + str(peaks["intensity"][i]) + str(peaks['mz'][i]) + + '\t' + + str(peaks['intensity'][i]) + peak_annotation - + "\n" + + '\n' ) diff --git a/fiora/IO/mspredReader.py b/fiora/IO/mspredReader.py index 60d253e..2610893 100644 --- a/fiora/IO/mspredReader.py +++ b/fiora/IO/mspredReader.py @@ -1,38 +1,39 @@ import json -import pandas as pd import os from collections import defaultdict +import pandas as pd + # Function was adjusted using code from https://github.com/samgoldman97/ms-pred def convert_dict_to_mz(values): peak_dict = defaultdict(lambda: {}) - for k, val in values["frags"].items(): - masses, intens = val["mz_charge"], val["intens"] + for k, val in values['frags'].items(): + masses, intens = val['mz_charge'], val['intens'] for m, i in zip(masses, intens): if i <= 0: continue current_peak_object = peak_dict[m] - if current_peak_object.get("inten", 0) > 0: + if current_peak_object.get('inten', 0) > 0: # update - if current_peak_object.get("inten") < i: - current_peak_object["frag_hash"] = k - current_peak_object["inten"] += i + if current_peak_object.get('inten') < i: + current_peak_object['frag_hash'] = k + current_peak_object['inten'] += i else: - current_peak_object["inten"] = i - current_peak_object["frag_hash"] = k + current_peak_object['inten'] = i + current_peak_object['frag_hash'] = k - max_inten = max(*[i["inten"] for i in peak_dict.values()], 1e-9) + max_inten = max(*[i['inten'] for i in peak_dict.values()], 1e-9) peak_dict = { - k: dict(inten=v["inten"] / 
max_inten, frag_hash=v["frag_hash"]) + k: dict(inten=v['inten'] / max_inten, frag_hash=v['frag_hash']) for k, v in peak_dict.items() } - peaks = {"mz": [], "intensity": []} + peaks = {'mz': [], 'intensity': []} for k, v in peak_dict.items(): - peaks["mz"].append(k) - peaks["intensity"].append(v["inten"]) + peaks['mz'].append(k) + peaks['intensity'].append(v['inten']) return peaks @@ -41,16 +42,16 @@ def read(dir): spectra = [] for file in os.listdir(dir): - if file.endswith(".json"): - temp_dict = {"file": file, "name": file.split(".")[0].replace("pred_", "")} + if file.endswith('.json'): + temp_dict = {'file': file, 'name': file.split('.')[0].replace('pred_', '')} try: - with open(os.path.join(dir, file), "r") as fp: + with open(os.path.join(dir, file), 'r') as fp: values = json.load(fp) peaks = convert_dict_to_mz(values) - temp_dict["peaks"] = peaks + temp_dict['peaks'] = peaks spectra.append(temp_dict) except Exception as _: - print(f"Warning: unable to read {file}") + print(f'Warning: unable to read {file}') return pd.DataFrame(spectra) diff --git a/fiora/IO/mspredWriter.py b/fiora/IO/mspredWriter.py index c8896eb..78f8c65 100644 --- a/fiora/IO/mspredWriter.py +++ b/fiora/IO/mspredWriter.py @@ -1,95 +1,96 @@ import os + from fiora.MOL.constants import ADDUCT_WEIGHTS label_header = [ - "dataset", - "spec", - "name", - "ionization", - "formula", - "smiles", - "inchikey", + 'dataset', + 'spec', + 'name', + 'ionization', + 'formula', + 'smiles', + 'inchikey', ] def write_labels(df, output_file, label_map, from_metabolite=True): if from_metabolite: - df["formula"] = df["Metabolite"].apply(lambda x: x.Formula) - df["smiles"] = df["Metabolite"].apply(lambda x: x.SMILES) - df["inchikey"] = df["Metabolite"].apply(lambda x: x.InChIKey) + df['formula'] = df['Metabolite'].apply(lambda x: x.Formula) + df['smiles'] = df['Metabolite'].apply(lambda x: x.SMILES) + df['inchikey'] = df['Metabolite'].apply(lambda x: x.InChIKey) - if "dataset" not in label_map.keys(): - df = df.drop(columns="dataset") + if 'dataset' not in label_map.keys(): + df = df.drop(columns='dataset') try: df.rename(columns=label_map)[label_header].to_csv( - output_file, index=False, sep="\t" + output_file, index=False, sep='\t' ) except Exception as _: raise NameError( - f"Failed to write labels file. Make sure file path is correct. Make sure all headers are present in DataFrame {label_header}. Use label_map to rename columns." + f'Failed to write labels file. Make sure file path is correct. Make sure all headers are present in DataFrame {label_header}. Use label_map to rename columns.' 
) -def write_spec_files(df, directory, spec_tag="spec"): +def write_spec_files(df, directory, spec_tag='spec'): for i, row in df.iterrows(): - output_file = os.path.join(directory, str(row[spec_tag]) + ".ms") - with open(output_file, "w") as f: - metabolite = row["Metabolite"] + output_file = os.path.join(directory, str(row[spec_tag]) + '.ms') + with open(output_file, 'w') as f: + metabolite = row['Metabolite'] # Write header - f.write(">compound " + row["Name"] + " \n") - f.write(">formula " + metabolite.Formula + " \n") + f.write('>compound ' + row['Name'] + ' \n') + f.write('>formula ' + metabolite.Formula + ' \n') f.write( - ">parentmass " - + str(metabolite.ExactMolWeight + ADDUCT_WEIGHTS[row["Precursor_type"]]) - + " \n" + '>parentmass ' + + str(metabolite.ExactMolWeight + ADDUCT_WEIGHTS[row['Precursor_type']]) + + ' \n' ) - f.write(">ionization " + row["Precursor_type"] + " \n") - f.write(">InChi " + metabolite.InChI + " \n") - f.write(">InChIKey " + metabolite.InChIKey + " \n") - f.write("#smiles " + metabolite.SMILES + " \n") - f.write("#scans " + "1" + " \n") - f.write("#_FILE " + str(row[spec_tag]) + " \n") - f.write("#spectrumid " + str(row[spec_tag]) + " \n") - f.write("#InChi " + metabolite.InChI + " \n") - f.write("\n") + f.write('>ionization ' + row['Precursor_type'] + ' \n') + f.write('>InChi ' + metabolite.InChI + ' \n') + f.write('>InChIKey ' + metabolite.InChIKey + ' \n') + f.write('#smiles ' + metabolite.SMILES + ' \n') + f.write('#scans ' + '1' + ' \n') + f.write('#_FILE ' + str(row[spec_tag]) + ' \n') + f.write('#spectrumid ' + str(row[spec_tag]) + ' \n') + f.write('#InChi ' + metabolite.InChI + ' \n') + f.write('\n') # Write peaks - f.write(">ms2peaks") - for j, mz in enumerate(row["peaks"]["mz"]): - f.write("\n") - f.write(str(mz) + " " + str(row["peaks"]["intensity"][j])) + f.write('>ms2peaks') + for j, mz in enumerate(row['peaks']['mz']): + f.write('\n') + f.write(str(mz) + ' ' + str(row['peaks']['intensity'][j])) def write_dataset( df, directory, label_map={ - "dataset": "dataset", - "spec": "spec", - "name": "name", - "formula": "formula", - "ionization": "ionization", - "smiles": "smiles", - "inchikey": "inchikey", + 'dataset': 'dataset', + 'spec': 'spec', + 'name': 'name', + 'formula': 'formula', + 'ionization': 'ionization', + 'smiles': 'smiles', + 'inchikey': 'inchikey', }, ): write_labels( df, - output_file=os.path.join(directory, "labels.tsv"), + output_file=os.path.join(directory, 'labels.tsv'), label_map=label_map, from_metabolite=True, ) write_labels( df.iloc[::-1], - output_file=os.path.join(directory, "reverse_labels.tsv"), + output_file=os.path.join(directory, 'reverse_labels.tsv'), label_map=label_map, from_metabolite=True, ) - spec_tag = {v: k for k, v in label_map.items()}["spec"] - spec_path = os.path.join(directory, "spec_files") + spec_tag = {v: k for k, v in label_map.items()}['spec'] + spec_path = os.path.join(directory, 'spec_files') if not os.path.exists(spec_path): os.mkdir(spec_path) write_spec_files(df, spec_path, spec_tag=spec_tag) diff --git a/fiora/MOL/FragmentationTree.py b/fiora/MOL/FragmentationTree.py index af87077..2908baf 100644 --- a/fiora/MOL/FragmentationTree.py +++ b/fiora/MOL/FragmentationTree.py @@ -1,12 +1,11 @@ -from fiora.MOL.mol_graph import mol_to_graph, get_adjacency_matrix, get_edges -from fiora.MS.ms_utility import do_mz_values_match -import fiora.MOL.constants as constants - from rdkit import Chem from rdkit.Chem import AllChem - from treelib import Tree +import fiora.MOL.constants as constants +from 
fiora.MOL.mol_graph import get_adjacency_matrix, get_edges, mol_to_graph +from fiora.MS.ms_utility import do_mz_values_match + # TODO can a fragment be tied to more than one edge: Yes. TODO see todo case in build_frag_tree @@ -20,15 +19,15 @@ def __init__(self, mol, edge=None, isotope_labels=None): a.GetIsotope() for a in mol.GetAtoms() ] # use isotope info as a proxy for node id break_side = ( - "left" + 'left' if edge[0] in subgraph - else "right" + else 'right' if edge[1] in subgraph - else "unidentified" + else 'unidentified' ) - if break_side == "unidentified": - print("ERROR", edge, subgraph, Chem.MolToSmiles(mol)) - raise ValueError("Unidentified edge in fragment") + if break_side == 'unidentified': + print('ERROR', edge, subgraph, Chem.MolToSmiles(mol)) + raise ValueError('Unidentified edge in fragment') self.break_sides = [break_side] self.subgraphs = [subgraph] else: @@ -51,8 +50,8 @@ def __init__(self, mol, edge=None, isotope_labels=None): } self.mz.update( { - mode.replace("]+", "]-"): self.neutral_mass - + constants.ADDUCT_WEIGHTS[mode.replace("]+", "]-")] + mode.replace(']+', ']-'): self.neutral_mass + + constants.ADDUCT_WEIGHTS[mode.replace(']+', ']-')] for mode in self.modes } ) @@ -63,10 +62,10 @@ def __eq__(self, __o: object) -> bool: return self.get_morganFinger() == __o.get_morganFinger() def __repr__(self): - return " :: " + self.smiles # + " " + str(self.mz) + return ' :: ' + self.smiles # + " " + str(self.mz) def __str__(self): - return " :: " + self.smiles # + " " + str(self.mz) + return ' :: ' + self.smiles # + " " + str(self.mz) def num_of_edges(self): return len(self.edges) @@ -100,16 +99,16 @@ def __init__(self, root_mol): self.edge_map = {None: Fragment(root_mol)} self.patt = Chem.MolFromSmarts( - "[!$([NH]!@C(=O))&!D1&!$(*#*)]-&!@[!$([NH]!@C(=O))&!D1&!$(*#*)]" + '[!$([NH]!@C(=O))&!D1&!$(*#*)]-&!@[!$([NH]!@C(=O))&!D1&!$(*#*)]' ) def __repr__(self): self.fragmentation_tree.show(idhidden=False) - return "" + return '' def __str__(self): self.fragmentation_tree.show(idhidden=False) - return "" + return '' """ Getter @@ -242,17 +241,17 @@ def match_peak_list(self, mz_list, int_list=None, tolerance=None): does_match, frag_ion = frag.match_peak(mz, tolerance=tolerance) if does_match: if was_peak_matched_already: - matches[mz]["fragments"] += [ + matches[mz]['fragments'] += [ frag ] # Report fragment for each edge leading to it - matches[mz]["ion_modes"] += [frag_ion] + matches[mz]['ion_modes'] += [frag_ion] else: matches[mz] = { - "intensity": int_list[i] if int_list else None, - "fragments": [ + 'intensity': int_list[i] if int_list else None, + 'fragments': [ frag ], # Report fragment for each edge leading to it - "ion_modes": [frag_ion], + 'ion_modes': [frag_ion], } was_peak_matched_already = True @@ -261,12 +260,12 @@ def match_peak_list(self, mz_list, int_list=None, tolerance=None): # return matches sum_intensity = sum( - [m["intensity"] for mz, m in matches.items() if m["intensity"] is not None] + [m['intensity'] for mz, m in matches.items() if m['intensity'] is not None] ) if sum_intensity > 0: - for mz in matches.keys(): - int_value = matches[mz]["intensity"] - matches[mz]["relative_intensity"] = ( + for mz, match in matches.items(): + int_value = match['intensity'] + match['relative_intensity'] = ( int_value / sum_intensity ) # only considered matched peaks @@ -372,13 +371,13 @@ def break_bond(self, mol, i, j, add_dummy_atoms=False): em.RemoveBond(i, j) if add_dummy_atoms: - em.AddAtom(Chem.Atom(0)) # - em.AddBond(i, num_atoms, Chem.BondType.SINGLE) # - 
em.AddAtom(Chem.Atom(0)) # - em.AddBond(j, num_atoms + 1, Chem.BondType.SINGLE) # + em.AddAtom(Chem.Atom(0)) + em.AddBond(i, num_atoms, Chem.BondType.SINGLE) + em.AddAtom(Chem.Atom(0)) + em.AddBond(j, num_atoms + 1, Chem.BondType.SINGLE) new_mol = em.GetMol() - Chem.SanitizeMol(new_mol) # + Chem.SanitizeMol(new_mol) frags = Chem.GetMolFrags(new_mol, asMols=True) return new_mol, frags diff --git a/fiora/MOL/Metabolite.py b/fiora/MOL/Metabolite.py index e05f5ab..b280001 100644 --- a/fiora/MOL/Metabolite.py +++ b/fiora/MOL/Metabolite.py @@ -1,37 +1,33 @@ import sys -import numpy as np -import torch import warnings -import matplotlib.pyplot as plt from typing import Literal -from rdkit import Chem -from rdkit.Chem import AllChem -from rdkit.Chem import Draw -from rdkit.Chem import Descriptors -from rdkit.Chem import rdMolDescriptors -from rdkit import DataStructs -from rdkit.Chem.Draw import rdMolDraw2D + +import matplotlib.pyplot as plt +import networkx as nx +import numpy as np +import torch from IPython.display import SVG, display +from rdkit import Chem, DataStructs +from rdkit.Chem import AllChem, Descriptors, Draw, rdMolDescriptors +from rdkit.Chem.Draw import rdMolDraw2D from torch_geometric.data import Data -import networkx as nx - +from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder +from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder +from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder from fiora.MOL.constants import ( - DEFAULT_PPM, - DEFAULT_MODE_MAP, ADDUCT_WEIGHTS, - ORDERED_ELEMENT_LIST_WITH_HYDROGEN, + DEFAULT_MODE_MAP, + DEFAULT_PPM, MAX_SUBGRAPH_NODES, + ORDERED_ELEMENT_LIST_WITH_HYDROGEN, ) +from fiora.MOL.FragmentationTree import FragmentationTree from fiora.MOL.mol_graph import ( - mol_to_graph, get_adjacency_matrix, get_edges, + mol_to_graph, ) -from fiora.MOL.FragmentationTree import FragmentationTree -from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder -from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder -from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder class Metabolite: @@ -43,7 +39,7 @@ def __init__( self.MOL = Chem.MolFromSmiles(self.SMILES) if not self.MOL: raise AssertionError( - "Molecule invalid; could not be generated from SMILES" + 'Molecule invalid; could not be generated from SMILES' ) self.InChI = Chem.MolToInchi(self.MOL) self.InChIKey = Chem.InchiToInchiKey(self.InChI) @@ -52,12 +48,12 @@ def __init__( self.MOL = Chem.MolFromInchi(self.InChI) if not self.MOL: raise AssertionError( - "Molecule invalid; could not be generated from InChI" + 'Molecule invalid; could not be generated from InChI' ) self.InChIKey = Chem.InchiToInchiKey(self.InChI) self.SMILES = Chem.MolToSmiles(self.MOL) else: - raise ValueError("Neither SMILES nor InChI were specified.") + raise ValueError('Neither SMILES nor InChI were specified.') self.ExactMolWeight = Descriptors.ExactMolWt(self.MOL) self.Formula = rdMolDescriptors.CalcMolFormula(self.MOL) @@ -72,10 +68,10 @@ def __init__( self.loss_weight = 1.0 def __repr__(self): - return f"" + return f'' def __str__(self): - return f"" + return f'' def __eq__(self, __o: object) -> bool: if self.ExactMolWeight != __o.ExactMolWeight: @@ -87,7 +83,7 @@ def __eq__(self, __o: object) -> bool: def __lt__(self, __o: object) -> bool: # TODO not tested!s warnings.warn( - "Warning: < operation for Metabolite class is not tested. Potentially flawed." + 'Warning: < operation for Metabolite class is not tested. Potentially flawed.' 
) if self.ExactMolWeight < __o.ExactMolWeight: return True @@ -109,11 +105,11 @@ def set_loss_weight(self, weight): def get_theoretical_precursor_mz(self, ion_type: str = None): if ion_type is None: - if hasattr(self, "metadata") and "precursor_mode" in self.metadata: - ion_type = self.metadata["precursor_mode"] + if hasattr(self, 'metadata') and 'precursor_mode' in self.metadata: + ion_type = self.metadata['precursor_mode'] else: raise ValueError( - "Ion type is not specified and no precursor_mode found in metadata." + 'Ion type is not specified and no precursor_mode found in metadata.' ) return self.ExactMolWeight + ADDUCT_WEIGHTS[ion_type] @@ -121,16 +117,16 @@ def get_morganFinger(self): return self.morganFinger def tanimoto_similarity( - self, __o: object, finger: Literal["morgan2", "morgan3"] = "morgan2" + self, __o: object, finger: Literal['morgan2', 'morgan3'] = 'morgan2' ): - if finger == "morgan2": + if finger == 'morgan2': return DataStructs.TanimotoSimilarity( self.get_morganFinger(), __o.get_morganFinger() ) - if finger == "morgan3": + if finger == 'morgan3': return DataStructs.TanimotoSimilarity(self.morganFinger3, __o.morganFinger3) raise ValueError( - f"Unknown type of fingerprint: {finger}. Cannot compare Metabolites." + f'Unknown type of fingerprint: {finger}. Cannot compare Metabolites.' ) def draw(self, ax=plt, show: bool = False, high_res: bool = False): @@ -150,14 +146,14 @@ def draw(self, ax=plt, show: bool = False, high_res: bool = False): img = Draw.MolToImage(self.MOL, ax=ax) ax.grid(False) ax.tick_params( - axis="both", + axis='both', bottom=False, labelbottom=False, left=False, labelleft=False, ) ax.imshow(img) - ax.axis("off") + ax.axis('off') if show: plt.show() return img @@ -188,26 +184,26 @@ def compute_graph_attributes( # Labels self.is_node_aromatic = torch.tensor( - [[self.Graph.nodes[atom]["is_aromatic"] for atom in self.Graph.nodes()]], + [[self.Graph.nodes[atom]['is_aromatic'] for atom in self.Graph.nodes()]], dtype=torch.float32, ).t() self.is_edge_aromatic = torch.tensor( [ [ - self.Graph[u][v]["bond_type"].name == "AROMATIC" + self.Graph[u][v]['bond_type'].name == 'AROMATIC' for u, v in self.edges_as_tuples ] ], dtype=torch.float32, ).t() self.is_edge_in_ring = torch.tensor( - [[self.Graph[u][v]["bond"].IsInRing() for u, v in self.edges_as_tuples]], + [[self.Graph[u][v]['bond'].IsInRing() for u, v in self.edges_as_tuples]], dtype=torch.float32, ).t() self.is_edge_not_in_ring = torch.tensor( [ [ - not self.Graph[u][v]["bond"].IsInRing() + not self.Graph[u][v]['bond'].IsInRing() for u, v in self.edges_as_tuples ] ], @@ -224,35 +220,35 @@ def compute_graph_attributes( # Lists if not memory_safe: self.atoms_in_order = [ - self.Graph.nodes[atom]["atom"] for atom in self.Graph.nodes() + self.Graph.nodes[atom]['atom'] for atom in self.Graph.nodes() ] self.node_elements = [ - self.Graph.nodes[atom]["atom"].GetSymbol() + self.Graph.nodes[atom]['atom'].GetSymbol() for atom in self.Graph.nodes() ] self.edge_bond_names = [ - self.Graph[u][v]["bond_type"].name for u, v in self.edges_as_tuples + self.Graph[u][v]['bond_type'].name for u, v in self.edges_as_tuples ] # Features if node_encoder: - self.node_features = node_encoder.encode(self.Graph, encoder_type="number") + self.node_features = node_encoder.encode(self.Graph, encoder_type='number') self.node_features_one_hot = node_encoder.encode( - self.Graph, encoder_type="one_hot" + self.Graph, encoder_type='one_hot' ) if bond_encoder: self.edge_bond_types = torch.tensor( [ - 
bond_encoder.number_mapper["bond_type"][bond_name] + bond_encoder.number_mapper['bond_type'][bond_name] for bond_name in self.edge_bond_names ], dtype=torch.int64, ) self.bond_features = bond_encoder.encode( - self.Graph, self.edges_as_tuples, encoder_type="number" + self.Graph, self.edges_as_tuples, encoder_type='number' ) self.bond_features_one_hot = bond_encoder.encode( - self.Graph, self.edges_as_tuples, encoder_type="one_hot" + self.Graph, self.edges_as_tuples, encoder_type='one_hot' ) else: self.bond_features = torch.zeros( @@ -268,7 +264,7 @@ def add_metadata( max_RT=30.0, ): self.metadata = metadata - mol_metadata = {"molecular_weight": self.ExactMolWeight} + mol_metadata = {'molecular_weight': self.ExactMolWeight} metadata.update(mol_metadata) if not process_metadata: return @@ -278,13 +274,13 @@ def add_metadata( self.setup_features_per_edge = covariate_encoder.encode( len(self.edges_as_tuples), metadata, G=self.Graph ) - if "ce_steps" in metadata: + if 'ce_steps' in metadata: self.ce_steps = torch.tensor( [ covariate_encoder.normalize_collision_steps( - metadata["ce_steps"] + metadata['ce_steps'] ) - + [np.nan for _ in range(7 - len(metadata["ce_steps"]))] + + [np.nan for _ in range(7 - len(metadata['ce_steps']))] ] ) # nan padding else: @@ -292,7 +288,7 @@ def add_metadata( 0 ) self.ce_idx = torch.tensor( - covariate_encoder.one_hot_mapper["collision_energy"], dtype=int + covariate_encoder.one_hot_mapper['collision_energy'], dtype=int ).unsqueeze(dim=-1) else: self.setup_features = torch.zeros(1, 0, dtype=torch.float32) @@ -305,34 +301,34 @@ def add_metadata( 1, metadata, G=self.Graph ) - if "retention_time" in metadata.keys(): + if 'retention_time' in metadata.keys(): if ( - not metadata["retention_time"] - or np.isnan(metadata["retention_time"]) - or "GC" in str(metadata["instrument"]) - or metadata["retention_time"] > max_RT + not metadata['retention_time'] + or np.isnan(metadata['retention_time']) + or 'GC' in str(metadata['instrument']) + or metadata['retention_time'] > max_RT ): - metadata["retention_time"] = np.nan + metadata['retention_time'] = np.nan self.rt = torch.tensor([np.nan]).unsqueeze(dim=-1) self.rt_mask = torch.tensor([0], dtype=torch.bool).unsqueeze(dim=-1) else: - self.rt = torch.tensor([metadata["retention_time"]]).unsqueeze(dim=-1) + self.rt = torch.tensor([metadata['retention_time']]).unsqueeze(dim=-1) self.rt_mask = torch.tensor([1], dtype=torch.bool).unsqueeze(dim=-1) else: self.rt = torch.tensor([torch.nan]).unsqueeze(dim=-1) self.rt_mask = torch.tensor([0], dtype=torch.bool).unsqueeze(dim=-1) - if "ccs" in metadata.keys(): + if 'ccs' in metadata.keys(): if ( - not metadata["ccs"] - or np.isnan(metadata["ccs"]) - or "GC" in str(metadata["instrument"]) + not metadata['ccs'] + or np.isnan(metadata['ccs']) + or 'GC' in str(metadata['instrument']) ): - metadata["ccs"] = np.nan + metadata['ccs'] = np.nan self.ccs_mask = torch.tensor([0], dtype=torch.bool).unsqueeze(dim=-1) self.ccs = torch.tensor([np.nan]).unsqueeze(dim=-1) else: - self.ccs = torch.tensor([metadata["ccs"]]).unsqueeze(dim=-1) + self.ccs = torch.tensor([metadata['ccs']]).unsqueeze(dim=-1) self.ccs_mask = torch.tensor([1], dtype=torch.bool).unsqueeze(dim=-1) else: self.ccs = torch.tensor([torch.nan]).unsqueeze(dim=-1) @@ -391,14 +387,14 @@ def extract_subgraph_features_from_edges(self) -> None: if (u, v) in edge_map: frag_list = edge_map[(u, v)] if frag_list != {}: - left_fragment = frag_list["left"] - right_fragment = frag_list["right"] + left_fragment = frag_list['left'] + right_fragment = 
frag_list['right'] else: if (v, u) in edge_map: frag_list = edge_map[(v, u)] if frag_list != {}: - left_fragment = frag_list["right"] - right_fragment = frag_list["left"] + left_fragment = frag_list['right'] + right_fragment = frag_list['left'] # Initialize element composition for the edge edge_elem_comp = torch.zeros( @@ -435,7 +431,7 @@ def extract_subgraph_features_from_edges(self) -> None: or len(right_nodes) > MAX_SUBGRAPH_NODES ): warnings.warn( - f"Metabolite {self.SMILES}: Subgraph size ({max(len(left_nodes), len(right_nodes))}) exceeds MAX_SUBGRAPH_NODES ({MAX_SUBGRAPH_NODES}). Truncating." + f'Metabolite {self.SMILES}: Subgraph size ({max(len(left_nodes), len(right_nodes))}) exceeds MAX_SUBGRAPH_NODES ({MAX_SUBGRAPH_NODES}). Truncating.' ) len_left = min(len(left_nodes), MAX_SUBGRAPH_NODES) @@ -450,7 +446,7 @@ def extract_subgraph_features_from_edges(self) -> None: @staticmethod def _edge_count_cols(mode_map, mode_count, ion_mode, break_side): base_col = mode_map[ion_mode] - if break_side == "left": + if break_side == 'left': return base_col, base_col + mode_count return base_col + mode_count, base_col @@ -468,7 +464,7 @@ def match_fragments_to_peaks( self.edge_breaks = [ frag.edges for mz in self.peak_matches.keys() - for frag in self.peak_matches[mz]["fragments"] + for frag in self.peak_matches[mz]['fragments'] ] self.edge_breaks = [ e for edges in self.edge_breaks for e in edges @@ -493,21 +489,21 @@ def match_fragments_to_peaks( # Flatten out all edges from fragments self.edge_intensities = [] for mz in self.peak_matches.keys(): - intensity = self.peak_matches[mz]["intensity"] / sum( - f.num_of_edges() for f in self.peak_matches[mz]["fragments"] + intensity = self.peak_matches[mz]['intensity'] / sum( + f.num_of_edges() for f in self.peak_matches[mz]['fragments'] ) - self.peak_matches[mz]["edges"] = [ - e for f in self.peak_matches[mz]["fragments"] for e in f.edges + self.peak_matches[mz]['edges'] = [ + e for f in self.peak_matches[mz]['fragments'] for e in f.edges ] - for i, f in enumerate(self.peak_matches[mz]["fragments"]): + for i, f in enumerate(self.peak_matches[mz]['fragments']): for j, edge in enumerate(f.edges): entry = ( edge, { - "intensity": intensity, - "fragment": f, - "break_side": f.break_sides[j], - "ion_mode": self.peak_matches[mz]["ion_modes"][i][0], + 'intensity': intensity, + 'fragment': f, + 'break_side': f.break_sides[j], + 'ion_mode': self.peak_matches[mz]['ion_modes'][i][0], }, ) @@ -531,7 +527,7 @@ def match_fragments_to_peaks( # Determining edge break probabilites from peak intensities. Multiple edges for the same fragment -> divide by number of edges. Multiple fragments from edge -> add intensities. 
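+        # Worked example of the scheme above (numbers illustrative): a peak
+        # of intensity 0.6 matched by two fragments with 2 + 1 broken edges
+        # gives a per-edge share of 0.6 / 3 = 0.2; an edge occurring in both
+        # fragments' edge lists accumulates 0.2 + 0.2 = 0.4.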
for edge, values in self.edge_intensities: if edge is None: # precursor - self.precursor_count += values["intensity"] + self.precursor_count += values['intensity'] continue edge_index = ( torch.logical_or( @@ -542,7 +538,7 @@ def match_fragments_to_peaks( .nonzero() .squeeze() ) - self.edge_break_count[edge_index] += values["intensity"] + self.edge_break_count[edge_index] += values['intensity'] forward_idx = ( ((torch.tensor(edge) == self.edges).sum(dim=1) == 2).nonzero().squeeze() @@ -553,10 +549,10 @@ def match_fragments_to_peaks( .squeeze() ) forward_col, backward_col = get_edge_count_cols( - mode_map, mode_count, values["ion_mode"], values["break_side"] + mode_map, mode_count, values['ion_mode'], values['break_side'] ) - self.edge_count_matrix[forward_idx, forward_col] += values["intensity"] - self.edge_count_matrix[backward_idx, backward_col] += values["intensity"] + self.edge_count_matrix[forward_idx, forward_col] += values['intensity'] + self.edge_count_matrix[backward_idx, backward_col] += values['intensity'] # "bond_features_one_hot", # Compile probability vectors @@ -598,66 +594,66 @@ def match_fragments_to_peaks( max_intensity = max(int_list) intensity_filter_threshold = 0.01 self.match_stats = { - "counts": self.compiled_countsALL.sum().tolist() + 'counts': self.compiled_countsALL.sum().tolist() / 2.0, # self.compiled_counts.sum().tolist() / 2.0, - "ms_all_counts": sum(int_list), - "coverage": (self.compiled_countsALL.sum().tolist() / 2.0) / sum(int_list), - "coverage_wo_prec": (self.edge_break_count.sum().tolist() / 2.0) + 'ms_all_counts': sum(int_list), + 'coverage': (self.compiled_countsALL.sum().tolist() / 2.0) / sum(int_list), + 'coverage_wo_prec': (self.edge_break_count.sum().tolist() / 2.0) / (sum(int_list) - self.precursor_count.tolist()), - "precursor_prob": self.precursor_count.tolist() + 'precursor_prob': self.precursor_count.tolist() / (self.compiled_countsALL.sum().tolist() / 2.0) if (self.compiled_countsALL.sum().tolist() / 2.0) > 0 else 0.0, - "precursor_raw_prob": self.precursor_count.tolist() / sum(int_list), - "num_peaks": len(mz_fragments), - "num_peak_matches": len(self.peak_matches), - "percent_peak_matches": len(self.peak_matches) / len(mz_fragments), - "num_peaks_filtered": sum( + 'precursor_raw_prob': self.precursor_count.tolist() / sum(int_list), + 'num_peaks': len(mz_fragments), + 'num_peak_matches': len(self.peak_matches), + 'percent_peak_matches': len(self.peak_matches) / len(mz_fragments), + 'num_peaks_filtered': sum( [(i / max_intensity) > intensity_filter_threshold for i in int_list] ), - "num_peak_matches_filtered": sum( + 'num_peak_matches_filtered': sum( [ - match["relative_intensity"] > intensity_filter_threshold + match['relative_intensity'] > intensity_filter_threshold for mz, match in self.peak_matches.items() ] ), - "percent_peak_matches_filtered": sum( + 'percent_peak_matches_filtered': sum( [ - match["relative_intensity"] > intensity_filter_threshold + match['relative_intensity'] > intensity_filter_threshold for mz, match in self.peak_matches.items() ] ) / len(mz_fragments), - "num_non_precursor_matches": sum( + 'num_non_precursor_matches': sum( [ - (None not in match["edges"]) + (None not in match['edges']) for mz, match in self.peak_matches.items() ] ), - "num_peak_match_conflicts": sum( - [len(match["edges"]) > 1 for mz, match in self.peak_matches.items()] + 'num_peak_match_conflicts': sum( + [len(match['edges']) > 1 for mz, match in self.peak_matches.items()] ), - "num_fragment_conflicts": sum( - [len(match["fragments"]) > 1 for mz, 
match in self.peak_matches.items()] + 'num_fragment_conflicts': sum( + [len(match['fragments']) > 1 for mz, match in self.peak_matches.items()] ), - "rel_fragment_conflicts": sum( - [len(match["fragments"]) > 1 for mz, match in self.peak_matches.items()] + 'rel_fragment_conflicts': sum( + [len(match['fragments']) > 1 for mz, match in self.peak_matches.items()] ) / sum( [ - (None not in match["edges"]) + (None not in match['edges']) for mz, match in self.peak_matches.items() ] ) if sum( [ - (None not in match["edges"]) + (None not in match['edges']) for mz, match in self.peak_matches.items() ] ) > 0 else 0, - "ms_num_all_peaks": len(mz_fragments), + 'ms_num_all_peaks': len(mz_fragments), } if match_stats_only: @@ -669,32 +665,32 @@ def get_memory_usage(self): } total_size = sum(memory_usage.values()) return { - "attributes": dict( + 'attributes': dict( sorted(memory_usage.items(), key=lambda x: x[1], reverse=True) ), - "total_size": total_size, + 'total_size': total_size, } def free_memory(self): attributes_to_free = [ - "edge_break_count", - "precursor_count", - "precursor_prob", - "precursor_sqrt_prob", - "edge_count_matrix", - "compiled_countsALL", - "compiled_probsALL", - "compiled_countsSQRT", - "compiled_probsSQRT", - "compiled_validation_maskALL", - "edge_breaks", - "edge_intensities", - "setup_features", - "setup_features_per_edge", - "node_features", - "node_features_one_hot", - "bond_features", - "bond_features_one_hot", + 'edge_break_count', + 'precursor_count', + 'precursor_prob', + 'precursor_sqrt_prob', + 'edge_count_matrix', + 'compiled_countsALL', + 'compiled_probsALL', + 'compiled_countsSQRT', + 'compiled_probsSQRT', + 'compiled_validation_maskALL', + 'edge_breaks', + 'edge_intensities', + 'setup_features', + 'setup_features_per_edge', + 'node_features', + 'node_features_one_hot', + 'bond_features', + 'bond_features_one_hot', ] # Tensors from peak matching for attr in attributes_to_free: diff --git a/fiora/MOL/MetaboliteDatasetStatistics.py b/fiora/MOL/MetaboliteDatasetStatistics.py index cb0f099..6561303 100644 --- a/fiora/MOL/MetaboliteDatasetStatistics.py +++ b/fiora/MOL/MetaboliteDatasetStatistics.py @@ -1,7 +1,9 @@ -import pandas as pd from collections import Counter -from fiora.MOL.Metabolite import Metabolite + +import pandas as pd + from fiora.MOL.constants import ORDERED_ELEMENT_LIST +from fiora.MOL.Metabolite import Metabolite class MetaboliteDatasetStatistics: @@ -21,25 +23,25 @@ def _compute_element_composition_stats(self): """ all_meta = [] for _, row in self.data.iterrows(): - metabolite = row["Metabolite"] + metabolite = row['Metabolite'] if isinstance(metabolite, Metabolite): # Create nested dictionaries for presence and count metadata_dict = { - "element_presence": { + 'element_presence': { element: int(element in metabolite.node_elements) for element in ORDERED_ELEMENT_LIST }, - "element_count": { + 'element_count': { element: metabolite.node_elements.count(element) for element in ORDERED_ELEMENT_LIST }, } # Add additional metadata - metadata_dict["ExactMolWeight"] = metabolite.ExactMolWeight - metadata_dict["Formula"] = metabolite.Formula - metadata_dict["SMILES"] = metabolite.SMILES - metadata_dict["InChIKey"] = metabolite.InChIKey - metadata_dict["TotalElements"] = len( + metadata_dict['ExactMolWeight'] = metabolite.ExactMolWeight + metadata_dict['Formula'] = metabolite.Formula + metadata_dict['SMILES'] = metabolite.SMILES + metadata_dict['InChIKey'] = metabolite.InChIKey + metadata_dict['TotalElements'] = len( metabolite.node_elements ) # Total number 
of elements all_meta.append(metadata_dict) @@ -51,7 +53,7 @@ def _compute_element_summary(self): Compute total counts, presence probability for each element, and ANY_RARE probability across the entire dataset. :return: dict with total counts, presence probabilities for each element, and ANY_RARE probability. """ - individual_stats = self.statistics["Individual_molecular_stats"] + individual_stats = self.statistics['Individual_molecular_stats'] # Initialize counters total_counts = Counter() @@ -62,13 +64,13 @@ def _compute_element_summary(self): rare_elements = [ element for element in ORDERED_ELEMENT_LIST - if element not in ["C", "O", "N", "H"] + if element not in ['C', 'O', 'N', 'H'] ] # Aggregate counts and presence probabilities for _, row in individual_stats.iterrows(): - element_counts = row["element_count"] - element_presence = row["element_presence"] + element_counts = row['element_count'] + element_presence = row['element_presence'] total_counts.update(element_counts) presence_counts.update(element_presence) @@ -84,11 +86,11 @@ def _compute_element_summary(self): } # Compute ANY_RARE probability and add it as another "element" - presence_probabilities["ANY_RARE"] = any_rare_count / total_molecules + presence_probabilities['ANY_RARE'] = any_rare_count / total_molecules return { - "Total Counts": total_counts, - "Presence Probabilities": presence_probabilities, + 'Total Counts': total_counts, + 'Presence Probabilities': presence_probabilities, } def generate_molecular_statistics(self, unique_compounds: bool = True): @@ -97,19 +99,19 @@ def generate_molecular_statistics(self, unique_compounds: bool = True): """ # Retrieve detailed information for each metabolite if unique_compounds: - self.data = self.data.drop_duplicates(subset="group_id") - self.statistics["Individual_molecular_stats"] = ( + self.data = self.data.drop_duplicates(subset='group_id') + self.statistics['Individual_molecular_stats'] = ( self._compute_element_composition_stats() ) - self.statistics["Molecular Summary"] = self._compute_element_summary() + self.statistics['Molecular Summary'] = self._compute_element_summary() def _compute_duplicates(self): """ Compute duplicate occurrences based on 'group_id'. :return: pd.DataFrame with group_id counts. """ - group_counts = self.data["group_id"].value_counts().reset_index() - group_counts.columns = ["group_id", "Count"] + group_counts = self.data['group_id'].value_counts().reset_index() + group_counts.columns = ['group_id', 'Count'] return group_counts def get_statistics(self): @@ -119,6 +121,6 @@ def get_statistics(self): """ if not self.statistics: raise ValueError( - "Statistics have not been generated yet. Call generate_molecular_statistics() first." + 'Statistics have not been generated yet. Call generate_molecular_statistics() first.' 
) return self.statistics diff --git a/fiora/MOL/MetaboliteIndex.py b/fiora/MOL/MetaboliteIndex.py index 42830f3..ae66867 100644 --- a/fiora/MOL/MetaboliteIndex.py +++ b/fiora/MOL/MetaboliteIndex.py @@ -1,6 +1,7 @@ from typing import List, Literal -from fiora.MOL.Metabolite import Metabolite + from fiora.MOL.FragmentationTree import FragmentationTree +from fiora.MOL.Metabolite import Metabolite class MetaboliteIndex: @@ -14,21 +15,21 @@ def index_metabolites(self, list_of_metabolites: List) -> None: metabolite.set_id(id) else: new_id = len(self.metabolite_index) - self.metabolite_index[new_id] = {"Metabolite": metabolite} + self.metabolite_index[new_id] = {'Metabolite': metabolite} metabolite.set_id(new_id) def create_fragmentation_trees(self, depth: int = 1) -> None: for id, entry in self.metabolite_index.items(): - metabolite = entry["Metabolite"] - entry["FragmentationTree"] = FragmentationTree(metabolite.MOL) - entry["FragmentationTree"].build_fragmentation_tree( + metabolite = entry['Metabolite'] + entry['FragmentationTree'] = FragmentationTree(metabolite.MOL) + entry['FragmentationTree'].build_fragmentation_tree( metabolite.MOL, metabolite.edges_as_tuples, depth=depth ) def add_fragmentation_trees_to_metabolite_list( self, list_of_metabolites: List[Metabolite], - graph_mismatch_policy: Literal["ignore", "recompute"] = "recompute", + graph_mismatch_policy: Literal['ignore', 'recompute'] = 'recompute', ) -> None: list_of_mismatched_ids = [] @@ -38,17 +39,17 @@ def add_fragmentation_trees_to_metabolite_list( # Check if metabolite edges align with the index if ( metabolite.edges_as_tuples - == self.metabolite_index[id]["Metabolite"].edges_as_tuples + == self.metabolite_index[id]['Metabolite'].edges_as_tuples ): metabolite.add_fragmentation_tree( - self.metabolite_index[id]["FragmentationTree"] + self.metabolite_index[id]['FragmentationTree'] ) else: - if graph_mismatch_policy == "recompute": + if graph_mismatch_policy == 'recompute': metabolite.fragment_MOL() - elif graph_mismatch_policy == "ignore": + elif graph_mismatch_policy == 'ignore': metabolite.add_fragmentation_tree( - self.metabolite_index[id]["FragmentationTree"] + self.metabolite_index[id]['FragmentationTree'] ) else: raise ValueError( @@ -60,7 +61,7 @@ def add_fragmentation_trees_to_metabolite_list( def find_metabolite_id(self, metabolite: Metabolite) -> int: for id, entry in self.metabolite_index.items(): - if metabolite == entry["Metabolite"]: + if metabolite == entry['Metabolite']: return id return None @@ -68,7 +69,7 @@ def get_metabolite(self, id: int) -> Metabolite: return self.metabolite_index[id] def get_fragmentation_tree(self, id: int) -> FragmentationTree: - return self.metabolite_index[id]["FragmentationTree"] + return self.metabolite_index[id]['FragmentationTree'] def get_number_of_metabolites(self) -> int: return len(self.metabolite_index) diff --git a/fiora/MOL/collision_energy.py b/fiora/MOL/collision_energy.py index 95cc739..20bfdbb 100644 --- a/fiora/MOL/collision_energy.py +++ b/fiora/MOL/collision_energy.py @@ -1,10 +1,10 @@ charge_factor = {1: 1, 2: 0.9, 3: 0.85, 4: 0.8, 5: 0.75} nce_instruments = [ - "Orbitrap", - "LC-ESI-QFT", - "LC-APCI-ITFT", - "Linear Ion Trap", - "LC-ESI-ITFT", + 'Orbitrap', + 'LC-ESI-QFT', + 'LC-APCI-ITFT', + 'Linear Ion Trap', + 'LC-ESI-ITFT', ] # "Flow-injection QqQ/MS", @@ -19,58 +19,58 @@ def align_CE(ce, precursor_mz, instrument=None): if instrument in nce_instruments: return NCE_to_eV(ce, precursor_mz) return ce - if "keV" in ce: - ce = ce.replace("keV", "") + if 'keV' in ce: 
+        ce = ce.replace('keV', '')
         return float(ce) * 1000
-    if "eV" in ce:
-        ce = ce.replace("eV", "")
+    if 'eV' in ce:
+        ce = ce.replace('eV', '')
         try:
             return float(ce)
         except Exception as _:
             return ce
-    elif "V" in ce:
-        ce = ce.replace("V", "")
+    elif 'V' in ce:
+        ce = ce.replace('V', '')
         try:
             return float(ce)
         except Exception as _:
             return ce
-    elif "ev" in ce:
-        ce = ce.replace("ev", "")
+    elif 'ev' in ce:
+        ce = ce.replace('ev', '')
         try:
             return float(ce)
         except Exception as _:
             return ce
-    elif "% (nominal)" in ce:
+    elif '% (nominal)' in ce:
         try:
-            nce = ce.split("% (nominal)")[0].strip().split(" ")[-1]
+            nce = ce.split('% (nominal)')[0].strip().split(' ')[-1]
             nce = float(nce)
             return NCE_to_eV(nce, precursor_mz)
         except Exception as _:
             return ce
-    elif "(nominal)" in ce:
+    elif '(nominal)' in ce:
         try:
-            nce = ce.split("(nominal)")[0].strip().split(" ")[-1]
+            nce = ce.split('(nominal)')[0].strip().split(' ')[-1]
             nce = float(nce)
             return NCE_to_eV(nce, precursor_mz)
         except Exception as _:
             return ce
-    elif "(NCE)" in ce:
+    elif '(NCE)' in ce:
         try:
-            nce = ce.strip().split("(NCE)")[0]
+            nce = ce.strip().split('(NCE)')[0]
             nce = float(nce)
             return NCE_to_eV(nce, precursor_mz)
         except Exception as _:
             return ce
-    elif "HCD" in ce:
+    elif 'HCD' in ce:
         try:
-            nce = ce.strip().split("HCD")[0]
+            nce = ce.strip().split('HCD')[0]
             nce = float(nce)
             return NCE_to_eV(nce, precursor_mz)
         except Exception as _:
             return ce
-    elif "%" in ce:
+    elif '%' in ce:
         try:
-            nce = ce.split("%")[0].strip().split(" ")[-1]
+            nce = ce.split('%')[0].strip().split(' ')[-1]
             nce = float(nce)
             return NCE_to_eV(nce, precursor_mz)
         except Exception as _:
diff --git a/fiora/MOL/constants.py b/fiora/MOL/constants.py
index 49c1cfe..c32fbcf 100644
--- a/fiora/MOL/constants.py
+++ b/fiora/MOL/constants.py
@@ -1,24 +1,24 @@
 from rdkit import Chem
 from rdkit.Chem import Descriptors
-h_minus = Chem.MolFromSmiles("[H-]") # hydrid
-h_plus = Chem.MolFromSmiles("[H+]") # h proton
-h_2 = Chem.MolFromSmiles("[HH]") # h2
+h_minus = Chem.MolFromSmiles('[H-]') # hydride
+h_plus = Chem.MolFromSmiles('[H+]') # h proton
+h_2 = Chem.MolFromSmiles('[HH]') # h2
 ADDUCT_WEIGHTS = {
-    "[M+H]+": Descriptors.ExactMolWt(h_plus), # 1.007276,
-    "[M+H]-": Descriptors.ExactMolWt(h_plus), # TODO might not technically exist
-    "[M+NH4]+": 18.033823,
-    "[M+Na]+": 22.989218,
-    "[M-H]-": -1 * Descriptors.ExactMolWt(h_plus),
     #
     # positive fragment rearrangements
     #
+    '[M+H]+': Descriptors.ExactMolWt(h_plus), # 1.007276,
+    '[M+H]-': Descriptors.ExactMolWt(h_plus), # TODO might not technically exist
+    '[M+NH4]+': 18.033823,
+    '[M+Na]+': 22.989218,
+    '[M-H]-': -1 * Descriptors.ExactMolWt(h_plus),
-    "[M-H]+": -1
+    '[M-H]+': -1
     * Descriptors.ExactMolWt(h_minus), # Double bond replacing 2 hydrogen atoms + H
-    "[M]+": 0,
-    "[M-2H]+": -1 * Descriptors.ExactMolWt(h_2), # Loosing proton and hydrid
-    "[M-3H]+": -1 * Descriptors.ExactMolWt(h_2)
+    '[M]+': 0,
+    '[M-2H]+': -1 * Descriptors.ExactMolWt(h_2), # Losing a proton and a hydride
+    '[M-3H]+': -1 * Descriptors.ExactMolWt(h_2)
     - 1 * Descriptors.ExactMolWt(h_minus), # 2 Double bonds + H
     # experimental cases
     # "[M-4H]+": -1.007276 * 4,
@@ -27,32 +27,32 @@
     # negative fragment rearrangements
     #
     # "[M-H]-": -1*Chem.Descriptors.ExactMolWt(h_plus), # see above
-    "[M]-": 0, # could be one electron too many
-    "[M-2H]-": -1 * Descriptors.ExactMolWt(h_2),
-    "[M-3H]-": -1 * Descriptors.ExactMolWt(h_2)
+    '[M]-': 0, # could be one electron too many
+    '[M-2H]-': -1 * Descriptors.ExactMolWt(h_2),
+    '[M-3H]-': -1 * Descriptors.ExactMolWt(h_2)
    - 1 *
Chem.Descriptors.ExactMolWt(h_plus), # # Hydrogen gains # - "[M+2H]+": Descriptors.ExactMolWt(h_plus) + '[M+2H]+': Descriptors.ExactMolWt(h_plus) + 1 * Descriptors.ExactMolWt( - Chem.MolFromSmiles("[H]") + Chem.MolFromSmiles('[H]') ), # 1 proton + 1 neutral hydrogens - "[M+3H]+": Descriptors.ExactMolWt(h_plus) + '[M+3H]+': Descriptors.ExactMolWt(h_plus) + 2 * Descriptors.ExactMolWt( - Chem.MolFromSmiles("[H]") + Chem.MolFromSmiles('[H]') ), # 1 proton + 2 neutral hydrogens - "[M+2H]-": Descriptors.ExactMolWt(h_plus) + '[M+2H]-': Descriptors.ExactMolWt(h_plus) + 1 * Descriptors.ExactMolWt( - Chem.MolFromSmiles("[H]") + Chem.MolFromSmiles('[H]') ), # 1 proton + 2 neutral hydrogens - "[M+3H]-": Descriptors.ExactMolWt(h_plus) + '[M+3H]-': Descriptors.ExactMolWt(h_plus) + 2 * Descriptors.ExactMolWt( - Chem.MolFromSmiles("[H]") + Chem.MolFromSmiles('[H]') ), # 1 proton + 2 neutral hydrogens } @@ -65,11 +65,11 @@ ) # DEFAULT_MODES = ["[M+H]+", "[M-H]+", "[M-3H]+"] DEFAULT_MODES = [ - "[M+H]+", - "[M]+", - "[M-H]+", - "[M-2H]+", - "[M-3H]+", + '[M+H]+', + '[M]+', + '[M-H]+', + '[M-2H]+', + '[M-3H]+', ] # "[M-4H]+"] #, "[M-5H]+"] DEFAULT_MODE_MAP = {mode: i for i, mode in enumerate(DEFAULT_MODES)} # NEGATIVE_MODES = ["[M]-", "[M-H]-", "[M-2H]-", "[M-3H]-", "[M-4H]-"] @@ -78,18 +78,18 @@ ORDERED_ELEMENT_LIST = [ - "Br", - "C", - "Cl", - "F", - "I", - "N", - "O", - "P", - "S", + 'Br', + 'C', + 'Cl', + 'F', + 'I', + 'N', + 'O', + 'P', + 'S', ] # Warning: Changes may affect model and version control ORDERED_ELEMENT_LIST_WITH_HYDROGEN = ORDERED_ELEMENT_LIST + [ - "H" + 'H' ] # Hydrogen is added at the end for element composition encoding MAX_SUBGRAPH_NODES = ( diff --git a/fiora/MOL/mol_graph.py b/fiora/MOL/mol_graph.py index 8b1727e..3f451c5 100644 --- a/fiora/MOL/mol_graph.py +++ b/fiora/MOL/mol_graph.py @@ -1,14 +1,13 @@ -import numpy as np import matplotlib.pyplot as plt -import torch - import networkx as nx +import numpy as np +import torch -node_color_map = {"C": "gray", "O": "red", "N": "blue"} +node_color_map = {'C': 'gray', 'O': 'red', 'N': 'blue'} -edge_color_map = {"SINGLE": "black", "DOUBLE": "black", "AROMATIC": "blue"} +edge_color_map = {'SINGLE': 'black', 'DOUBLE': 'black', 'AROMATIC': 'blue'} -edge_width_map = {"SINGLE": 1.5, "DOUBLE": 3, "AROMATIC": 3} +edge_width_map = {'SINGLE': 1.5, 'DOUBLE': 3, 'AROMATIC': 3} def mol_to_graph(mol): @@ -18,7 +17,7 @@ def mol_to_graph(mol): color = ( node_color_map[atom.GetSymbol()] if atom.GetSymbol() in node_color_map.keys() - else "black" + else 'black' ) G.add_node( atom.GetIdx(), @@ -47,20 +46,20 @@ def draw_graph(G, ax=None, edge_labels=False): G, ax=ax, pos=pos, - labels=nx.get_node_attributes(G, "atom_symbol"), + labels=nx.get_node_attributes(G, 'atom_symbol'), with_labels=True, - node_color=list(nx.get_node_attributes(G, "color").values()), + node_color=list(nx.get_node_attributes(G, 'color').values()), node_size=800, # edges=G.edges(), - edge_color=[edge_color_map[G[u][v]["bond_type"].name] for u, v in G.edges], - width=[edge_width_map[G[u][v]["bond_type"].name] for u, v in G.edges], + edge_color=[edge_color_map[G[u][v]['bond_type'].name] for u, v in G.edges], + width=[edge_width_map[G[u][v]['bond_type'].name] for u, v in G.edges], ) if edge_labels: nx.draw_networkx_edge_labels( G, pos, - edge_labels=dict([((u, v), f"({u}, {v})") for u, v in G.edges]), - font_color="red", + edge_labels=dict([((u, v), f'({u}, {v})') for u, v in G.edges]), + font_color='red', ax=ax, ) diff --git a/fiora/MS/SimulationFramework.py 
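# Hypothetical usage of the mol_graph helpers above (names taken from
# this file; ethanol SMILES chosen arbitrarily for illustration):
import matplotlib.pyplot as plt
from rdkit import Chem

from fiora.MOL.mol_graph import draw_graph, mol_to_graph

G = mol_to_graph(Chem.MolFromSmiles('CCO'))
fig, ax = plt.subplots()
draw_graph(G, ax=ax, edge_labels=True)
plt.show()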
b/fiora/MS/SimulationFramework.py index 164d573..393a0f2 100644 --- a/fiora/MS/SimulationFramework.py +++ b/fiora/MS/SimulationFramework.py @@ -1,28 +1,30 @@ +from typing import Dict, Literal + +import numpy as np +import pandas as pd import torch import torch_geometric as geom -import pandas as pd -import numpy as np -from typing import Literal, Dict + +from fiora.MOL.constants import DEFAULT_MODE_MAP from fiora.MOL.Metabolite import Metabolite from fiora.MS.spectral_scores import ( + reweighted_dot, spectral_cosine, spectral_reflection_cosine, - reweighted_dot, ) -from fiora.MOL.constants import DEFAULT_MODE_MAP class SimulationFramework: - def __init__(self, base_model: torch.nn.Module | None = None, dev: str = "cpu"): + def __init__(self, base_model: torch.nn.Module | None = None, dev: str = 'cpu'): self.base_model = base_model self.dev = dev self.mode_map = None def __repr__(self): - return "Simulation framework for MS/MS spectrum generation" + return 'Simulation framework for MS/MS spectrum generation' def __str__(self): - return "Simulation framework for MS/MS spectrum generation" + return 'Simulation framework for MS/MS spectrum generation' def set_mode_mapper(self, mode_map): self.mode_map = mode_map @@ -38,8 +40,8 @@ def predict_metabolite_property( logits = model( data, - with_RT=hasattr(model, "rt_module"), - with_CCS=hasattr(model, "ccs_module"), + with_RT=hasattr(model, 'rt_module'), + with_CCS=hasattr(model, 'ccs_module'), ) return logits @@ -47,25 +49,25 @@ def pred_all( self, df: pd.DataFrame, model: torch.nn.Module | None = None, - attr_name: str = "", + attr_name: str = '', as_batch: bool = True, ): with torch.no_grad(): model.eval() for i, d in df.iterrows(): - metabolite = d["Metabolite"] + metabolite = d['Metabolite'] prediction = self.predict_metabolite_property( metabolite, model=model, as_batch=as_batch ) - if hasattr(model, "rt_module"): + if hasattr(model, 'rt_module'): setattr( - metabolite, attr_name + "_pred", prediction["fragment_probs"] + metabolite, attr_name + '_pred', prediction['fragment_probs'] ) - setattr(metabolite, "RT_pred", prediction["rt"].squeeze()) + setattr(metabolite, 'RT_pred', prediction['rt'].squeeze()) else: setattr( - metabolite, attr_name + "_pred", prediction["fragment_probs"] + metabolite, attr_name + '_pred', prediction['fragment_probs'] ) return @@ -73,10 +75,10 @@ def simulate_spectrum( self, metabolite: Metabolite, pred_label: str, - precursor_mode: Literal["[M+H]+", "[M-H]-"] = "[M+H]+", + precursor_mode: Literal['[M+H]+', '[M-H]-'] = '[M+H]+', min_intensity: float = 0.001, merge_fragment_duplicates: bool = True, - transform_prob: str = "None", + transform_prob: str = 'None', ): if not self.mode_map: @@ -87,16 +89,16 @@ def simulate_spectrum( edge_map = metabolite.fragmentation_tree.edge_map sim_probs = getattr(metabolite, pred_label) - sim_peaks = {"mz": [], "intensity": [], "annotation": []} + sim_peaks = {'mz': [], 'intensity': [], 'annotation': []} precursor_prob = sim_probs[-1].tolist() precursor = edge_map[None] - sim_peaks["mz"].append( + sim_peaks['mz'].append( precursor.mz[precursor_mode] ) # TODO allow multiple ion modes of precursor - sim_peaks["intensity"].append(precursor_prob) - sim_peaks["annotation"].append(precursor.smiles + "//" + precursor_mode) + sim_peaks['intensity'].append(precursor_prob) + sim_peaks['annotation'].append(precursor.smiles + '//' + precursor_mode) edge_probs = sim_probs[:-2].unflatten(-1, sizes=(-1, len(mode_map) * 2)) @@ -107,7 +109,7 @@ def simulate_spectrum( if not frags: continue - lf = 
frags.get("left") + lf = frags.get('left') if lf: for mode, idx in mode_map.items(): intensity = edge_probs[i, idx].tolist() @@ -115,30 +117,30 @@ def simulate_spectrum( mz = lf.mz[mode] mode_str = ( mode - if precursor_mode == "[M+H]+" - else mode.replace("]+", "]-") + if precursor_mode == '[M+H]+' + else mode.replace(']+', ']-') ) - annotation = lf.smiles + "//" + mode_str + annotation = lf.smiles + '//' + mode_str merged = False if merge_fragment_duplicates and ( - mz in sim_peaks["mz"] + mz in sim_peaks['mz'] ): # if exact mz value exists already - for j, mzx in enumerate(sim_peaks["mz"]): + for j, mzx in enumerate(sim_peaks['mz']): if ( mz == mzx - and annotation == sim_peaks["annotation"][j] + and annotation == sim_peaks['annotation'][j] ): # check mz and annotation - sim_peaks["intensity"][j] += ( + sim_peaks['intensity'][j] += ( intensity # and intensity if exact same fragments ) merged = True break if merged: continue - sim_peaks["mz"].append(mz) - sim_peaks["intensity"].append(intensity) - sim_peaks["annotation"].append(annotation) - rf = frags.get("right") + sim_peaks['mz'].append(mz) + sim_peaks['intensity'].append(intensity) + sim_peaks['annotation'].append(annotation) + rf = frags.get('right') if rf: for mode, idx in mode_map.items(): idx = (idx + len(mode_map)) % (2 * len(mode_map)) @@ -148,42 +150,42 @@ def simulate_spectrum( mz = rf.mz[mode] mode_str = ( mode - if precursor_mode == "[M+H]+" - else mode.replace("]+", "]-") + if precursor_mode == '[M+H]+' + else mode.replace(']+', ']-') ) - annotation = rf.smiles + "//" + mode_str + annotation = rf.smiles + '//' + mode_str merged = False - if merge_fragment_duplicates and (mz in sim_peaks["mz"]): - for j, mzx in enumerate(sim_peaks["mz"]): + if merge_fragment_duplicates and (mz in sim_peaks['mz']): + for j, mzx in enumerate(sim_peaks['mz']): if ( mz == mzx - and annotation == sim_peaks["annotation"][j] + and annotation == sim_peaks['annotation'][j] ): - sim_peaks["intensity"][j] += ( + sim_peaks['intensity'][j] += ( intensity # and intensity if exact same fragments ) merged = True break if merged: continue - sim_peaks["mz"].append(mz) - sim_peaks["intensity"].append(intensity) - sim_peaks["annotation"].append(annotation) + sim_peaks['mz'].append(mz) + sim_peaks['intensity'].append(intensity) + sim_peaks['annotation'].append(annotation) - if transform_prob == "square": - max_prob = max(sim_peaks["intensity"]) ** 2 - for i in range(len(sim_peaks["intensity"])): - sim_peaks["intensity"][i] = sim_peaks["intensity"][i] ** 2 / max_prob + if transform_prob == 'square': + max_prob = max(sim_peaks['intensity']) ** 2 + for i in range(len(sim_peaks['intensity'])): + sim_peaks['intensity'][i] = sim_peaks['intensity'][i] ** 2 / max_prob combined = sorted( - zip(sim_peaks["mz"], sim_peaks["intensity"], sim_peaks["annotation"]), + zip(sim_peaks['mz'], sim_peaks['intensity'], sim_peaks['annotation']), key=lambda t: t[0], reverse=True, ) mz, inten, annot = zip(*combined) - sim_peaks["mz"] = list(mz) - sim_peaks["intensity"] = list(inten) - sim_peaks["annotation"] = list(annot) + sim_peaks['mz'] = list(mz) + sim_peaks['intensity'] = list(inten) + sim_peaks['annotation'] = list(annot) return sim_peaks @@ -191,7 +193,7 @@ def simulate_and_score( self, metabolite: Metabolite, model: torch.nn.Module | None = None, - base_attr_name: str = "compiled_probsALL", + base_attr_name: str = 'compiled_probsALL', query_peaks: Dict | None = None, as_batch: bool = True, min_intensity: float = 0.001, @@ -201,25 +203,25 @@ def simulate_and_score( ) stats = {} - 
if "rt" in prediction.keys(): - stats["RT_pred"] = prediction["rt"].squeeze().tolist() - if "ccs" in prediction.keys(): - stats["CCS_pred"] = prediction["ccs"].squeeze().tolist() + if 'rt' in prediction.keys(): + stats['RT_pred'] = prediction['rt'].squeeze().tolist() + if 'ccs' in prediction.keys(): + stats['CCS_pred'] = prediction['ccs'].squeeze().tolist() - setattr(metabolite, base_attr_name + "_pred", prediction["fragment_probs"]) - training_label = model.model_params.get("training_label") + setattr(metabolite, base_attr_name + '_pred', prediction['fragment_probs']) + training_label = model.model_params.get('training_label') transform_prob = ( - "square" + 'square' if ( - training_label == "compiled_probsSQRT" - or (training_label is None and base_attr_name == "compiled_probsSQRT") + training_label == 'compiled_probsSQRT' + or (training_label is None and base_attr_name == 'compiled_probsSQRT') ) - else "None" + else 'None' ) - stats["sim_peaks"] = self.simulate_spectrum( + stats['sim_peaks'] = self.simulate_spectrum( metabolite, - base_attr_name + "_pred", - precursor_mode=metabolite.metadata["precursor_mode"], + base_attr_name + '_pred', + precursor_mode=metabolite.metadata['precursor_mode'], transform_prob=transform_prob, min_intensity=min_intensity, ) @@ -228,52 +230,52 @@ def simulate_and_score( if hasattr(metabolite, base_attr_name): groundtruth = getattr(metabolite, base_attr_name).to(self.dev) - stats["cosine_similarity"] = torch.nn.functional.cosine_similarity( - prediction["fragment_probs"], groundtruth, dim=0 + stats['cosine_similarity'] = torch.nn.functional.cosine_similarity( + prediction['fragment_probs'], groundtruth, dim=0 ).tolist() # TODO - stats["kl_div"] = torch.nn.functional.kl_div( - torch.log(prediction["fragment_probs"]), groundtruth, reduction="sum" + stats['kl_div'] = torch.nn.functional.kl_div( + torch.log(prediction['fragment_probs']), groundtruth, reduction='sum' ).tolist() - if "RT_pred" in stats.keys() and "retention_time" in metabolite.metadata.keys(): - stats["RT_dif"] = abs( - stats["RT_pred"] - metabolite.metadata["retention_time"] + if 'RT_pred' in stats.keys() and 'retention_time' in metabolite.metadata.keys(): + stats['RT_dif'] = abs( + stats['RT_pred'] - metabolite.metadata['retention_time'] ) if query_peaks: - stats["spectral_cosine"], stats["spectral_bias"] = spectral_cosine( - query_peaks, stats["sim_peaks"], with_bias=True + stats['spectral_cosine'], stats['spectral_bias'] = spectral_cosine( + query_peaks, stats['sim_peaks'], with_bias=True ) - stats["spectral_sqrt_cosine"], stats["spectral_sqrt_bias"] = ( + stats['spectral_sqrt_cosine'], stats['spectral_sqrt_bias'] = ( spectral_cosine( - query_peaks, stats["sim_peaks"], transform=np.sqrt, with_bias=True + query_peaks, stats['sim_peaks'], transform=np.sqrt, with_bias=True ) ) ( - stats["spectral_sqrt_cosine_wo_prec"], - stats["spectral_sqrt_bias_wo_prec"], + stats['spectral_sqrt_cosine_wo_prec'], + stats['spectral_sqrt_bias_wo_prec'], ) = spectral_cosine( query_peaks, - stats["sim_peaks"], + stats['sim_peaks'], transform=np.sqrt, remove_mz=metabolite.get_theoretical_precursor_mz( - ion_type=metabolite.metadata["precursor_mode"] + ion_type=metabolite.metadata['precursor_mode'] ), with_bias=True, ) - stats["spectral_sqrt_cosine_avg"], stats["spectral_sqrt_bias_avg"] = ( - (stats["spectral_sqrt_cosine"] + stats["spectral_sqrt_cosine_wo_prec"]) + stats['spectral_sqrt_cosine_avg'], stats['spectral_sqrt_bias_avg'] = ( + (stats['spectral_sqrt_cosine'] + stats['spectral_sqrt_cosine_wo_prec']) / 2.0, - 
(stats["spectral_sqrt_bias"] + stats["spectral_sqrt_bias_wo_prec"]) + (stats['spectral_sqrt_bias'] + stats['spectral_sqrt_bias_wo_prec']) / 2.0, ) - stats["spectral_refl_cosine"], stats["spectral_refl_bias"] = ( + stats['spectral_refl_cosine'], stats['spectral_refl_bias'] = ( spectral_reflection_cosine( - query_peaks, stats["sim_peaks"], transform=np.sqrt, with_bias=True + query_peaks, stats['sim_peaks'], transform=np.sqrt, with_bias=True ) ) - stats["steins_cosine"], stats["steins_bias"] = reweighted_dot( - query_peaks, stats["sim_peaks"], int_pow=0.5, mz_pow=0.5, with_bias=True + stats['steins_cosine'], stats['steins_bias'] = reweighted_dot( + query_peaks, stats['sim_peaks'], int_pow=0.5, mz_pow=0.5, with_bias=True ) return stats @@ -281,12 +283,12 @@ def simulate_all( self, df: pd.DataFrame, model: torch.nn.Module | None = None, - base_attr_name: str = "compiled_probsALL", - suffix: str = "", + base_attr_name: str = 'compiled_probsALL', + suffix: str = '', groundtruth=True, min_intensity: float = 0.001, progress: bool = False, - progress_desc: str = "Evaluate", + progress_desc: str = 'Evaluate', ): with torch.no_grad(): @@ -302,12 +304,12 @@ def simulate_all( pass for i, data in iterator: - metabolite = data["Metabolite"] + metabolite = data['Metabolite'] stats = self.simulate_and_score( metabolite, model, base_attr_name, - query_peaks=data["peaks"] if groundtruth else None, + query_peaks=data['peaks'] if groundtruth else None, min_intensity=min_intensity, ) df = pd.concat( @@ -320,7 +322,7 @@ def simulate_all( setattr(metabolite, key + suffix, value) else: raise Warning( - "User Warning: Attempting to add data to non-existing column simulate_all().\n\tSolve by adding column with pd.concat()" + 'User Warning: Attempting to add data to non-existing column simulate_all().\n\tSolve by adding column with pd.concat()' ) return df @@ -330,6 +332,7 @@ def plot_feature_prediction_vectors( ): import matplotlib.pyplot as plt + import fiora.visualization.spectrum_visualizer as sv if with_mol: @@ -337,7 +340,7 @@ def plot_feature_prediction_vectors( 1, 2, figsize=(12.8, 4.2), - gridspec_kw={"width_ratios": [1, 3]}, + gridspec_kw={'width_ratios': [1, 3]}, sharey=False, ) _ = metabolite.draw(ax=axs[0]) @@ -348,11 +351,11 @@ def plot_feature_prediction_vectors( metabolite.compiled_validation_mask, metabolite.compiled_forward_mask ) probs = getattr(metabolite, label).to(self.dev)[relevant_edge_index] - preds = getattr(metabolite, "predicted_" + label).to(self.dev)[ + preds = getattr(metabolite, 'predicted_' + label).to(self.dev)[ relevant_edge_index ] - names = [f"e{i}" for i in range(preds.shape[0] - 1)] + ["prec"] + names = [f'e{i}' for i in range(preds.shape[0] - 1)] + ['prec'] _ = sv.plot_vector_spectrum( probs.tolist(), preds.tolist(), ax=axs[1], names=names diff --git a/fiora/MS/ms_utility.py b/fiora/MS/ms_utility.py index f1b4db3..a84b74b 100644 --- a/fiora/MS/ms_utility.py +++ b/fiora/MS/ms_utility.py @@ -1,8 +1,10 @@ # from modules.MOL.FragmentationTree import FragmentationTree -from fiora.MOL.constants import PPM, DEFAULT_PPM, MIN_ABS_TOLERANCE +import copy from typing import Literal + import numpy as np -import copy + +from fiora.MOL.constants import DEFAULT_PPM, MIN_ABS_TOLERANCE, PPM def do_mz_values_match( @@ -42,37 +44,37 @@ def match_fragment_lists(mz_list, other_mz_list, tolerance=None): return uniques, multiples, unidentified -def normalize_spectrum(spec, type: Literal["max_intensity", "norm"] = "norm"): - if type == "max_intensity": - maximum = max(spec["intensity"]) - 
spec["intensity"] = [i / maximum for i in spec["intensity"]]
-    elif type == "norm":
-        spec["intensity"] = list(
-            np.array(spec["intensity"]) / np.linalg.norm(spec["intensity"])
+def normalize_spectrum(spec, type: Literal['max_intensity', 'norm'] = 'norm'):
+    if type == 'max_intensity':
+        maximum = max(spec['intensity'])
+        spec['intensity'] = [i / maximum for i in spec['intensity']]
+    elif type == 'norm':
+        spec['intensity'] = list(
+            np.array(spec['intensity']) / np.linalg.norm(spec['intensity'])
         )
     else:
-        raise ValueError("Unknown type of normalization")
+        raise ValueError('Unknown type of normalization')
 
 
 def merge_annotated_spectrum(spec1, spec2):
     spec1 = copy.deepcopy(spec1)
-    spec2_red = {"mz": [], "intensity": [], "annotation": []}
-    for i, mz2 in enumerate(spec2["mz"]):
+    spec2_red = {'mz': [], 'intensity': [], 'annotation': []}
+    for i, mz2 in enumerate(spec2['mz']):
         merged_peak = False
-        if mz2 in spec1["mz"]:
-            for j, mz1 in enumerate(spec1["mz"]):
-                if mz1 == mz2 and spec1["annotation"][j] == spec2["annotation"][i]:
-                    spec1["intensity"][j] += spec2["intensity"][i]
+        if mz2 in spec1['mz']:
+            for j, mz1 in enumerate(spec1['mz']):
+                if mz1 == mz2 and spec1['annotation'][j] == spec2['annotation'][i]:
+                    spec1['intensity'][j] += spec2['intensity'][i]
                     merged_peak = True
                     break
         if not merged_peak:
-            spec2_red["mz"] += [spec2["mz"][i]]
-            spec2_red["intensity"] += [spec2["intensity"][i]]
-            spec2_red["annotation"] += [spec2["annotation"][i]]
+            spec2_red['mz'] += [spec2['mz'][i]]
+            spec2_red['intensity'] += [spec2['intensity'][i]]
+            spec2_red['annotation'] += [spec2['annotation'][i]]
 
-    spec1["mz"] += spec2_red["mz"]
-    spec1["intensity"] += spec2_red["intensity"]
-    spec1["annotation"] += spec2_red["annotation"]
+    spec1['mz'] += spec2_red['mz']
+    spec1['intensity'] += spec2_red['intensity']
+    spec1['annotation'] += spec2_red['annotation']
 
     return spec1
 
@@ -81,24 +83,24 @@ def merge_spectrum(spec1, spec2, merge_tolerance: float = 0.0):
     spec1 = copy.deepcopy(spec1)
     if merge_tolerance > 0.01:
         raise Warning(
-            "Merging peaks recommended only for very small tolerances. Peak merging has mainly a visual impact is not needed for computation."
+            'Merging peaks is recommended only for very small tolerances. Peak merging has mainly a visual impact and is not needed for computation.'
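# Minimal usage sketch for the helpers above (peak dicts are plain
# {'mz': [...], 'intensity': [...]}; all values here are invented):
from fiora.MS.ms_utility import merge_spectrum, normalize_spectrum

a = {'mz': [100.0, 150.0], 'intensity': [1.0, 0.5]}
b = {'mz': [100.0005, 200.0], 'intensity': [0.2, 0.3]}
merged = merge_spectrum(a, b, merge_tolerance=0.001)  # peaks within 1 mDa collapse
normalize_spectrum(merged, type='norm')  # in-place L2 normalization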
) - spec2_red = {"mz": [], "intensity": []} - for i, mz2 in enumerate(spec2["mz"]): + spec2_red = {'mz': [], 'intensity': []} + for i, mz2 in enumerate(spec2['mz']): merged_peak = False - for j, mz1 in enumerate(spec1["mz"]): + for j, mz1 in enumerate(spec1['mz']): if abs(mz1 - mz2) <= merge_tolerance: - spec1["intensity"][j] += spec2["intensity"][i] - spec1["mz"][j] = (mz1 + mz2) / 2 + spec1['intensity'][j] += spec2['intensity'][i] + spec1['mz'][j] = (mz1 + mz2) / 2 merged_peak = True break if not merged_peak: - spec2_red["mz"] += [spec2["mz"][i]] - spec2_red["intensity"] += [spec2["intensity"][i]] + spec2_red['mz'] += [spec2['mz'][i]] + spec2_red['intensity'] += [spec2['intensity'][i]] - spec1["mz"] += spec2_red["mz"] - spec1["intensity"] += spec2_red["intensity"] + spec1['mz'] += spec2_red['mz'] + spec1['intensity'] += spec2_red['intensity'] return spec1 diff --git a/fiora/MS/spectral_scores.py b/fiora/MS/spectral_scores.py index e1e5a4d..d560c42 100644 --- a/fiora/MS/spectral_scores.py +++ b/fiora/MS/spectral_scores.py @@ -1,6 +1,5 @@ import numpy as np - from fiora.MOL.constants import DEFAULT_DALTON @@ -57,14 +56,14 @@ def spectral_cosine( with_bias=False, remove_mz: float | None = None, ): - mz_map = create_mz_map(spec["mz"], spec_ref["mz"], tolerance=tolerance) + mz_map = create_mz_map(spec['mz'], spec_ref['mz'], tolerance=tolerance) vec, vec_ref = np.zeros(len(mz_map)), np.zeros(len(mz_map)) - bins = list(map(mz_map.get, spec["mz"])) - bins_ref = list(map(mz_map.get, spec_ref["mz"])) + bins = list(map(mz_map.get, spec['mz'])) + bins_ref = list(map(mz_map.get, spec_ref['mz'])) - np.add.at(vec, bins, spec["intensity"]) # vec.put(bins, spec["intensity"]) - np.add.at(vec_ref, bins_ref, spec_ref["intensity"]) + np.add.at(vec, bins, spec['intensity']) # vec.put(bins, spec["intensity"]) + np.add.at(vec_ref, bins_ref, spec_ref['intensity']) # zero out specific mz value, e.g. 
precursor m/z if remove_mz: @@ -91,14 +90,14 @@ def spectral_cosine( def spectral_reflection_cosine( spec, spec_ref, tolerance=DEFAULT_DALTON, transform=None, with_bias=False ): - mz_map = create_mz_map(spec["mz"], spec_ref["mz"], tolerance=tolerance) + mz_map = create_mz_map(spec['mz'], spec_ref['mz'], tolerance=tolerance) vec, vec_ref = np.zeros(len(mz_map)), np.zeros(len(mz_map)) - bins = list(map(mz_map.get, spec["mz"])) - bins_ref = list(map(mz_map.get, spec_ref["mz"])) + bins = list(map(mz_map.get, spec['mz'])) + bins_ref = list(map(mz_map.get, spec_ref['mz'])) - np.add.at(vec, bins, spec["intensity"]) # vec.put(bins, spec["intensity"]) - np.add.at(vec_ref, bins_ref, spec_ref["intensity"]) + np.add.at(vec, bins, spec['intensity']) # vec.put(bins, spec["intensity"]) + np.add.at(vec_ref, bins_ref, spec_ref['intensity']) # Reflection score: Remove values that are not matched with the reference values unmatched_bins = [b for b in bins if b not in bins_ref] @@ -119,23 +118,23 @@ def spectral_reflection_cosine( def reweighted_dot( spec, spec_ref, int_pow=0.5, mz_pow=0.5, tolerance=DEFAULT_DALTON, with_bias=False ): - mz_map = create_mz_map(spec["mz"], spec_ref["mz"], tolerance=tolerance) + mz_map = create_mz_map(spec['mz'], spec_ref['mz'], tolerance=tolerance) vec, vec_ref = np.zeros(len(mz_map)), np.zeros(len(mz_map)) - bins = list(map(mz_map.get, spec["mz"])) - bins_ref = list(map(mz_map.get, spec_ref["mz"])) + bins = list(map(mz_map.get, spec['mz'])) + bins_ref = list(map(mz_map.get, spec_ref['mz'])) - spec["mz_int"] = [ - np.power(spec["intensity"][i], int_pow) * np.power(mz, mz_pow) - for i, mz in enumerate(spec["mz"]) + spec['mz_int'] = [ + np.power(spec['intensity'][i], int_pow) * np.power(mz, mz_pow) + for i, mz in enumerate(spec['mz']) ] - spec_ref["mz_int"] = [ - np.power(spec_ref["intensity"][i], int_pow) * np.power(mz, mz_pow) - for i, mz in enumerate(spec_ref["mz"]) + spec_ref['mz_int'] = [ + np.power(spec_ref['intensity'][i], int_pow) * np.power(mz, mz_pow) + for i, mz in enumerate(spec_ref['mz']) ] - np.add.at(vec, bins, spec["mz_int"]) - np.add.at(vec_ref, bins_ref, spec_ref["mz_int"]) + np.add.at(vec, bins, spec['mz_int']) + np.add.at(vec_ref, bins_ref, spec_ref['mz_int']) cos = cosine(vec, vec_ref) if with_bias: diff --git a/fiora/cli/eval.py b/fiora/cli/eval.py index 94a1124..b3ae187 100644 --- a/fiora/cli/eval.py +++ b/fiora/cli/eval.py @@ -21,108 +21,108 @@ from fiora.MOL.MetaboliteIndex import MetaboliteIndex from fiora.MS.SimulationFramework import SimulationFramework -RDLogger.DisableLog("rdApp.*") -warnings.filterwarnings("ignore", category=SyntaxWarning) +RDLogger.DisableLog('rdApp.*') +warnings.filterwarnings('ignore', category=SyntaxWarning) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - prog="fiora-eval", - description="Evaluate a trained FIORA model on validation/test splits.", + prog='fiora-eval', + description='Evaluate a trained FIORA model on validation/test splits.', ) parser.add_argument( - "-i", - "--input", + '-i', + '--input', required=True, - help="Path to preprocessed CSV containing spectra/metadata/SMILES.", + help='Path to preprocessed CSV containing spectra/metadata/SMILES.', ) parser.add_argument( - "-m", - "--model", + '-m', + '--model', required=True, - help="Path to checkpoint .pt produced by fiora-train.", + help='Path to checkpoint .pt produced by fiora-train.', ) parser.add_argument( - "--device", - default="auto", - help="Device to run on (e.g. cpu, cuda:0). 
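# Illustrative call of the binned spectral scores defined above (peak
# values invented; with_bias=True returns a (score, bias) pair):
import numpy as np

from fiora.MS.spectral_scores import reweighted_dot, spectral_cosine

q = {'mz': [81.07, 109.10, 153.13], 'intensity': [0.2, 0.6, 1.0]}
r = {'mz': [81.07, 109.11, 200.00], 'intensity': [0.3, 0.5, 0.1]}
cos, bias = spectral_cosine(q, r, transform=np.sqrt, with_bias=True)
steins, steins_bias = reweighted_dot(q, r, int_pow=0.5, mz_pow=0.5, with_bias=True)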
Default: auto.", + '--device', + default='auto', + help='Device to run on (e.g. cpu, cuda:0). Default: auto.', ) parser.add_argument( - "--datasplit-col", - default="datasplit", - help="Column containing split labels (default: datasplit).", + '--datasplit-col', + default='datasplit', + help='Column containing split labels (default: datasplit).', ) parser.add_argument( - "--splits", - default="validation,test", - help="Comma-separated splits to evaluate (default: validation,test).", + '--splits', + default='validation,test', + help='Comma-separated splits to evaluate (default: validation,test).', ) parser.add_argument( - "--score", - default="spectral_sqrt_cosine", - help="Score column to summarize after evaluation.", + '--score', + default='spectral_sqrt_cosine', + help='Score column to summarize after evaluation.', ) parser.add_argument( - "--print-wo-prec", + '--print-wo-prec', action=argparse.BooleanOptionalAction, default=True, - help="Also print precursor-excluded score summaries when available (default: true).", + help='Also print precursor-excluded score summaries when available (default: true).', ) parser.add_argument( - "--y-label", + '--y-label', default=None, - help="Prediction target label used during training (default: from model params, fallback compiled_probsALL).", + help='Prediction target label used during training (default: from model params, fallback compiled_probsALL).', ) parser.add_argument( - "--min-prob", + '--min-prob', type=float, default=0.001, - help="Minimum predicted peak intensity to keep.", + help='Minimum predicted peak intensity to keep.', ) parser.add_argument( - "--fragmentation-depth", + '--fragmentation-depth', type=int, default=1, - help="Fragmentation depth for metabolite trees.", + help='Fragmentation depth for metabolite trees.', ) parser.add_argument( - "--graph-mismatch-policy", - choices=["recompute", "ignore"], - default="recompute", + '--graph-mismatch-policy', + choices=['recompute', 'ignore'], + default='recompute', ) - parser.add_argument("--summary-col", default="summary") - parser.add_argument("--peaks-col", default="peaks") - parser.add_argument("--smiles-col", default="SMILES") - parser.add_argument("--group-id-col", default="group_id") - parser.add_argument("--max-rows", type=int, default=None) + parser.add_argument('--summary-col', default='summary') + parser.add_argument('--peaks-col', default='peaks') + parser.add_argument('--smiles-col', default='SMILES') + parser.add_argument('--group-id-col', default='group_id') + parser.add_argument('--max-rows', type=int, default=None) parser.add_argument( - "--output-dir", + '--output-dir', default=None, - help="Optional directory to write evaluated split CSV files.", + help='Optional directory to write evaluated split CSV files.', ) parser.add_argument( - "--progress", + '--progress', action=argparse.BooleanOptionalAction, default=True, - help="Show tqdm progress bars (default: true).", + help='Show tqdm progress bars (default: true).', ) parser.add_argument( - "--index-col", + '--index-col', type=int, default=0, - help="CSV index column (default: 0). Use --no-index-col to disable.", + help='CSV index column (default: 0). 
Use --no-index-col to disable.', ) parser.add_argument( - "--no-index-col", - action="store_true", - help="Disable index_col when reading CSV.", + '--no-index-col', + action='store_true', + help='Disable index_col when reading CSV.', ) return parser.parse_args() def _resolve_device(device: str) -> str: - if device == "auto": - return "cuda:0" if torch.cuda.is_available() else "cpu" + if device == 'auto': + return 'cuda:0' if torch.cuda.is_available() else 'cpu' return device @@ -138,9 +138,9 @@ def _parse_dict(val): return json.loads(text) except Exception: pass - norm = re.sub(r"\b(?:NaN|nan)\b", "None", text) - norm = re.sub(r"\b(?:Infinity|inf)\b", "1e309", norm) - norm = re.sub(r"\b(?:-Infinity|-inf)\b", "-1e309", norm) + norm = re.sub(r'\b(?:NaN|nan)\b', 'None', text) + norm = re.sub(r'\b(?:Infinity|inf)\b', '1e309', norm) + norm = re.sub(r'\b(?:-Infinity|-inf)\b', '-1e309', norm) try: parsed = ast.literal_eval(norm) return parsed if isinstance(parsed, dict) else None @@ -164,13 +164,13 @@ def _safe_metabolite(smiles: str): def _build_summary_from_columns(row): metadata_key_map = { - "name": ["Name", "NAME", "Title", "TITLE"], - "collision_energy": ["CE", "COLLISION_ENERGY", "CollisionEnergy"], - "instrument": ["Instrument_type", "instrument", "INSTRUMENT_TYPE"], - "precursor_mode": ["Precursor_type", "ADDUCT", "PRECURSORTYPE"], - "precursor_mz": ["PrecursorMZ", "PEPMASS", "PRECURSORMZ"], - "retention_time": ["RETENTIONTIME", "RTINSECONDS", "retention_time"], - "ccs": ["CCS", "ccs"], + 'name': ['Name', 'NAME', 'Title', 'TITLE'], + 'collision_energy': ['CE', 'COLLISION_ENERGY', 'CollisionEnergy'], + 'instrument': ['Instrument_type', 'instrument', 'INSTRUMENT_TYPE'], + 'precursor_mode': ['Precursor_type', 'ADDUCT', 'PRECURSORTYPE'], + 'precursor_mz': ['PrecursorMZ', 'PEPMASS', 'PRECURSORMZ'], + 'retention_time': ['RETENTIONTIME', 'RTINSECONDS', 'retention_time'], + 'ccs': ['CCS', 'ccs'], } summary = {} for key, cols in metadata_key_map.items(): @@ -189,25 +189,25 @@ def _prepare_metabolites( df: pd.DataFrame, model, progress: bool = True ) -> tuple[pd.DataFrame, int]: setup_features = model.model_params.get( - "setup_features", + 'setup_features', [ - "collision_energy", - "molecular_weight", - "precursor_mode", - "instrument", - "element_composition", + 'collision_energy', + 'molecular_weight', + 'precursor_mode', + 'instrument', + 'element_composition', ], ) rt_features = model.model_params.get( - "rt_features", - ["molecular_weight", "precursor_mode", "instrument", "element_composition"], + 'rt_features', + ['molecular_weight', 'precursor_mode', 'instrument', 'element_composition'], ) - setup_sets = model.model_params.get("setup_features_categorical_set") + setup_sets = model.model_params.get('setup_features_categorical_set') node_encoder = AtomFeatureEncoder( - feature_list=["symbol", "num_hydrogen", "ring_type"] + feature_list=['symbol', 'num_hydrogen', 'ring_type'] ) - bond_encoder = BondFeatureEncoder(feature_list=["bond_type", "ring_type"]) + bond_encoder = BondFeatureEncoder(feature_list=['bond_type', 'ring_type']) setup_encoder = CovariateFeatureEncoder( feature_list=setup_features, sets_overwrite=setup_sets ) @@ -221,12 +221,12 @@ def _prepare_metabolites( try: from tqdm.auto import tqdm - iterator = tqdm(iterator, total=len(df), desc="Prepare metabolites") + iterator = tqdm(iterator, total=len(df), desc='Prepare metabolites') except Exception: pass for idx, row in iterator: - smiles = row.get("SMILES") + smiles = row.get('SMILES') if smiles is None or (isinstance(smiles, 
float) and np.isnan(smiles)): invalid_rows.append(idx) continue @@ -237,13 +237,13 @@ def _prepare_metabolites( mol.create_molecular_structure_graph() mol.compute_graph_attributes(node_encoder, bond_encoder) - if "group_id" in df.columns: + if 'group_id' in df.columns: try: - mol.set_id(int(row["group_id"])) + mol.set_id(int(row['group_id'])) except Exception: pass - summary = row.get("summary") + summary = row.get('summary') if summary is None: summary = _build_summary_from_columns(row) @@ -252,7 +252,7 @@ def _prepare_metabolites( except Exception: invalid_rows.append(idx) continue - df.at[idx, "Metabolite"] = mol + df.at[idx, 'Metabolite'] = mol if invalid_rows: df = df.drop(index=invalid_rows).copy() @@ -260,8 +260,8 @@ def _prepare_metabolites( def _load_model(path: str, dev: str): - state_path = path.replace(".pt", "_state.pt") - params_path = path.replace(".pt", "_params.json") + state_path = path.replace('.pt', '_state.pt') + params_path = path.replace('.pt', '_params.json') if os.path.exists(state_path) and os.path.exists(params_path): return FioraModel.load_from_state_dict(path).to(dev) return FioraModel.load(path).to(dev) @@ -269,10 +269,10 @@ def _load_model(path: str, dev: str): def _to_csv_safe(df: pd.DataFrame) -> pd.DataFrame: out = df.copy() - if "Metabolite" in out.columns: - out = out.drop(columns=["Metabolite"]) + if 'Metabolite' in out.columns: + out = out.drop(columns=['Metabolite']) for col in out.columns: - if out[col].dtype == "object": + if out[col].dtype == 'object': out[col] = out[col].apply( lambda v: json.dumps(v) if isinstance(v, (dict, list)) else v ) @@ -282,14 +282,14 @@ def _to_csv_safe(df: pd.DataFrame) -> pd.DataFrame: def _metric_stats(part: pd.DataFrame, metric: str) -> tuple[float, float] | None: if metric not in part.columns: return None - vals = pd.to_numeric(part[metric], errors="coerce") + vals = pd.to_numeric(part[metric], errors='coerce') return float(vals.mean()), float(vals.median()) def main() -> None: args = parse_args() dev = _resolve_device(args.device) - np.seterr(invalid="ignore") + np.seterr(invalid='ignore') index_col = None if args.no_index_col else args.index_col loader = LibraryLoader() @@ -303,58 +303,58 @@ def main() -> None: df = df.iloc[: args.max_rows].copy() df = _parse_dict_columns(df, [args.summary_col, args.peaks_col]) - splits = [x.strip() for x in args.splits.split(",") if x.strip()] + splits = [x.strip() for x in args.splits.split(',') if x.strip()] if not splits: - raise SystemExit("No valid --splits provided.") + raise SystemExit('No valid --splits provided.') if args.datasplit_col not in df.columns: raise SystemExit(f"datasplit column '{args.datasplit_col}' not found in input.") df = df[df[args.datasplit_col].isin(splits)].copy() - print(f"Loaded {len(df)} rows for splits: {splits}") + print(f'Loaded {len(df)} rows for splits: {splits}') if len(df) == 0: - raise SystemExit("No rows left after split filtering.") + raise SystemExit('No rows left after split filtering.') model = _load_model(args.model, dev) model.eval() y_label = args.y_label or model.model_params.get( - "training_label", "compiled_probsALL" + 'training_label', 'compiled_probsALL' ) if args.y_label is None: - print(f"Using y-label from model params: {y_label}") - elif model.model_params.get("training_label") and y_label != model.model_params.get( - "training_label" + print(f'Using y-label from model params: {y_label}') + elif model.model_params.get('training_label') and y_label != model.model_params.get( + 'training_label' ): print( - "Warning: 
--y-label does not match model training label " - f"({y_label} vs {model.model_params.get('training_label')})." + 'Warning: --y-label does not match model training label ' + f'({y_label} vs {model.model_params.get("training_label")}).' ) # Standardize user-configurable column names for downstream code. - if args.summary_col != "summary" and args.summary_col in df.columns: - df["summary"] = df[args.summary_col] - if args.peaks_col != "peaks" and args.peaks_col in df.columns: - df["peaks"] = df[args.peaks_col] - if args.smiles_col != "SMILES" and args.smiles_col in df.columns: - df["SMILES"] = df[args.smiles_col] - if args.group_id_col != "group_id" and args.group_id_col in df.columns: - df["group_id"] = df[args.group_id_col] + if args.summary_col != 'summary' and args.summary_col in df.columns: + df['summary'] = df[args.summary_col] + if args.peaks_col != 'peaks' and args.peaks_col in df.columns: + df['peaks'] = df[args.peaks_col] + if args.smiles_col != 'SMILES' and args.smiles_col in df.columns: + df['SMILES'] = df[args.smiles_col] + if args.group_id_col != 'group_id' and args.group_id_col in df.columns: + df['group_id'] = df[args.group_id_col] df, dropped = _prepare_metabolites(df, model, progress=args.progress) if dropped: - print(f"Dropped {dropped} invalid rows during metabolite preparation.") + print(f'Dropped {dropped} invalid rows during metabolite preparation.') mindex = MetaboliteIndex() - mindex.index_metabolites(df["Metabolite"]) + mindex.index_metabolites(df['Metabolite']) mindex.create_fragmentation_trees(depth=args.fragmentation_depth) mindex.add_fragmentation_trees_to_metabolite_list( - df["Metabolite"], graph_mismatch_policy=args.graph_mismatch_policy + df['Metabolite'], graph_mismatch_policy=args.graph_mismatch_policy ) fiora = SimulationFramework(None, dev=dev) - use_groundtruth = "peaks" in df.columns + use_groundtruth = 'peaks' in df.columns if not use_groundtruth: print( - "Warning: peaks column not found. Running prediction without score metrics." + 'Warning: peaks column not found. Running prediction without score metrics.' 
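# Example invocation of this evaluation script (assuming the module is
# runnable via `python -m`; file paths are placeholders):
#
#   python -m fiora.cli.eval \
#       -i data/library_preprocessed.csv \
#       -m checkpoints/fiora.best.pt \
#       --splits validation,test \
#       --score spectral_sqrt_cosine \
#       --output-dir eval_out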
) output_dir = None @@ -375,12 +375,12 @@ def main() -> None: groundtruth=use_groundtruth, min_intensity=args.min_prob, progress=args.progress, - progress_desc=f"{split} split", + progress_desc=f'{split} split', ) metrics_to_report = [args.score] if args.print_wo_prec: - for metric in ["spectral_sqrt_cosine_wo_prec", "spectral_sqrt_cosine_avg"]: + for metric in ['spectral_sqrt_cosine_wo_prec', 'spectral_sqrt_cosine_avg']: if metric != args.score: metrics_to_report.append(metric) @@ -393,14 +393,14 @@ def main() -> None: continue mean, median = stats summary_table.setdefault(metric, {})[split] = (mean, median) - summaries.append(f"{metric}_mean={mean:.5f} | {metric}_median={median:.5f}") + summaries.append(f'{metric}_mean={mean:.5f} | {metric}_median={median:.5f}') - print(f"Split '{split}': n={len(part)} | " + " | ".join(summaries)) + print(f"Split '{split}': n={len(part)} | " + ' | '.join(summaries)) if output_dir is not None: - out_path = output_dir / f"{split}_eval.csv" + out_path = output_dir / f'{split}_eval.csv' _to_csv_safe(part).to_csv(out_path, index=False) - print(f"Wrote {len(part)} rows to {out_path}") + print(f'Wrote {len(part)} rows to {out_path}') if summary_table: table = pd.DataFrame( @@ -408,11 +408,11 @@ def main() -> None: ) for metric, split_stats in summary_table.items(): for split, (mean, median) in split_stats.items(): - table.at[metric, split] = f"{mean:.5f} / {median:.5f}" - table = table.fillna("-") - print("\nSummary Table (mean / median):") + table.at[metric, split] = f'{mean:.5f} / {median:.5f}' + table = table.fillna('-') + print('\nSummary Table (mean / median):') print(table.to_string()) -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/fiora/cli/model_info.py b/fiora/cli/model_info.py index 0a48d64..696d0be 100644 --- a/fiora/cli/model_info.py +++ b/fiora/cli/model_info.py @@ -8,86 +8,86 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - prog="fiora-model-info", - description="Load a FIORA .pt checkpoint and print key model parameters.", + prog='fiora-model-info', + description='Load a FIORA .pt checkpoint and print key model parameters.', ) parser.add_argument( - "-m", - "--model", + '-m', + '--model', required=True, - help="Path to checkpoint .pt file.", + help='Path to checkpoint .pt file.', ) parser.add_argument( - "--as-json", + '--as-json', action=argparse.BooleanOptionalAction, default=False, - help="Print the full model_params dictionary as JSON.", + help='Print the full model_params dictionary as JSON.', ) return parser.parse_args() def _load_model(path: str): - state_path = path.replace(".pt", "_state.pt") - params_path = path.replace(".pt", "_params.json") + state_path = path.replace('.pt', '_state.pt') + params_path = path.replace('.pt', '_params.json') if os.path.exists(state_path) and os.path.exists(params_path): return FioraModel.load_from_state_dict(path) return FioraModel.load(path) -def _get(params: dict, key: str, default="-"): +def _get(params: dict, key: str, default='-'): return params.get(key, default) def _print_main_params(params: dict) -> None: keys = [ - "version", - "version_number", - "param_tag", - "training_label", - "gnn_type", - "depth", - "hidden_dimension", - "embedding_dimension", - "embedding_aggregation", - "layer_stacking", - "residual_connections", - "layer_norm", - "subgraph_features", - "pooling_func", - "dense_layers", - "dense_dim", - "input_dropout", - "latent_dropout", - "output_dimension", - "static_feature_dimension", - "static_rt_feature_dimension", - 
"prepare_additional_layers", - "rt_supported", - "ccs_supported", + 'version', + 'version_number', + 'param_tag', + 'training_label', + 'gnn_type', + 'depth', + 'hidden_dimension', + 'embedding_dimension', + 'embedding_aggregation', + 'layer_stacking', + 'residual_connections', + 'layer_norm', + 'subgraph_features', + 'pooling_func', + 'dense_layers', + 'dense_dim', + 'input_dropout', + 'latent_dropout', + 'output_dimension', + 'static_feature_dimension', + 'static_rt_feature_dimension', + 'prepare_additional_layers', + 'rt_supported', + 'ccs_supported', ] - print("Model parameters:") + print('Model parameters:') for key in keys: - print(f" {key}: {_get(params, key)}") + print(f' {key}: {_get(params, key)}') # concise feature summary - atom_features = _get(params, "atom_features", []) - setup_features = _get(params, "setup_features", []) - rt_features = _get(params, "rt_features", []) - print(" atom_features:", atom_features if atom_features else "-") - print(" setup_features:", setup_features if setup_features else "-") - print(" rt_features:", rt_features if rt_features else "-") + atom_features = _get(params, 'atom_features', []) + setup_features = _get(params, 'setup_features', []) + rt_features = _get(params, 'rt_features', []) + print(' atom_features:', atom_features if atom_features else '-') + print(' setup_features:', setup_features if setup_features else '-') + print(' rt_features:', rt_features if rt_features else '-') def main() -> None: args = parse_args() model = _load_model(args.model) - params = model.model_params if hasattr(model, "model_params") else {} - print(f"Loaded model: {args.model}") + params = model.model_params if hasattr(model, 'model_params') else {} + print(f'Loaded model: {args.model}') if args.as_json: print(json.dumps(params, indent=2, sort_keys=True)) return _print_main_params(params) -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/fiora/cli/predict.py b/fiora/cli/predict.py index 71cab0b..db7ba25 100644 --- a/fiora/cli/predict.py +++ b/fiora/cli/predict.py @@ -3,7 +3,7 @@ import importlib.resources as resources import warnings -warnings.filterwarnings("ignore", category=SyntaxWarning) +warnings.filterwarnings('ignore', category=SyntaxWarning) import pandas as pd from rdkit import RDLogger @@ -17,78 +17,78 @@ from fiora.MOL.Metabolite import Metabolite from fiora.MS.SimulationFramework import SimulationFramework -RDLogger.DisableLog("rdApp.*") +RDLogger.DisableLog('rdApp.*') warnings.filterwarnings( - "ignore", category=UserWarning, message="TypedStorage is deprecated" + 'ignore', category=UserWarning, message='TypedStorage is deprecated' ) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - prog="fiora-predict", + prog='fiora-predict', description=( - "Fiora is an in silico fragmentation framework, which predicts peaks and " - "simulates tandem mass spectra including features such as retention time " - "and collision cross sections. Use this script for spectrum predictions " - "with a (pre-)trained model." + 'Fiora is an in silico fragmentation framework, which predicts peaks and ' + 'simulates tandem mass spectra including features such as retention time ' + 'and collision cross sections. Use this script for spectrum predictions ' + 'with a (pre-)trained model.' ), - epilog="Disclaimer:\nNo prediction software is perfect. Use with caution.", + epilog='Disclaimer:\nNo prediction software is perfect. 
Use with caution.',
     )
     parser.add_argument(
-        "-i",
-        "--input",
-        help="Input file containing molecular structures (SMILES/InChi) and metadata (.csv file)",
+        '-i',
+        '--input',
+        help='Input file containing molecular structures (SMILES/InChI) and metadata (.csv file)',
         type=str,
         required=True,
     )
     parser.add_argument(
-        "-o",
-        "--output",
-        help="Output file path (.mgf/.msp file)",
+        '-o',
+        '--output',
+        help='Output file path (.mgf/.msp file)',
         type=str,
         required=True,
     )
     parser.add_argument(
-        "--model",
-        help="Path to prediction model (.pt file)",
+        '--model',
+        help='Path to prediction model (.pt file)',
         type=str,
-        default="default",
+        default='default',
     )
     parser.add_argument(
-        "--dev",
-        help="Device to the model. For example cuda:0 for GPU number 0.",
+        '--dev',
+        help='Device to run the model on. For example cuda:0 for GPU number 0.',
         type=str,
-        default="cpu",
+        default='cpu',
     )
     parser.add_argument(
-        "--min_prob",
-        help="Minimum peak probability to be recorded in the spectrum",
+        '--min_prob',
+        help='Minimum peak probability to be recorded in the spectrum',
         type=float,
         default=0.001,
     )
     parser.add_argument(
-        "--rt",
+        '--rt',
         action=argparse.BooleanOptionalAction,
-        help="Predict retention time",
+        help='Predict retention time',
         default=False,
     )
     parser.add_argument(
-        "--ccs",
+        '--ccs',
         action=argparse.BooleanOptionalAction,
-        help="Predict collison cross section",
+        help='Predict collision cross section',
         default=False,
     )
     parser.add_argument(
-        "--annotation",
+        '--annotation',
         action=argparse.BooleanOptionalAction,
-        help="Annotate predicted peaks with SMILES strings",
+        help='Annotate predicted peaks with SMILES strings',
         default=False,
     )
     parser.add_argument(
-        "--debug",
+        '--debug',
         action=argparse.BooleanOptionalAction,
-        help="Receive debug information",
+        help='Receive debug information',
         default=False,
     )
     return parser.parse_args()
@@ -97,36 +97,36 @@ def update_args_with_model_params(
     args: argparse.Namespace, model_params: dict
 ) -> argparse.Namespace:
-    if "rt_supported" in model_params.keys():
-        if not model_params["rt_supported"] and args.rt:
+    if 'rt_supported' in model_params.keys():
+        if not model_params['rt_supported'] and args.rt:
             print(
-                "Warning: RT prediction is not support by the model. Overwriting user argument to --no-rt.\n"
+                'Warning: RT prediction is not supported by the model. Overwriting user argument to --no-rt.\n'
             )
             args.rt = False
-    if "ccs_supported" in model_params.keys():
-        if not model_params["ccs_supported"] and args.ccs:
+    if 'ccs_supported' in model_params.keys():
+        if not model_params['ccs_supported'] and args.ccs:
             print(
-                "Warning: CCS prediction is not support by the model. Overwriting user argument to --no-ccs.\n"
+                'Warning: CCS prediction is not supported by the model.
Overwriting user argument to --no-ccs.\n' ) args.ccs = False return args def print_model_messages(model_params: dict) -> None: - if "version" in model_params.keys(): - print("\n-----Model-----") - print(model_params["version"]) - print("---------------") - if "disclaimer" in model_params.keys(): - dis_msg = model_params["disclaimer"] - print(f"\nDisclaimer: {dis_msg}") + if 'version' in model_params.keys(): + print('\n-----Model-----') + print(model_params['version']) + print('---------------') + if 'disclaimer' in model_params.keys(): + dis_msg = model_params['disclaimer'] + print(f'\nDisclaimer: {dis_msg}') metadata_key_map = { - "name": "Name", - "collision_energy": "CE", - "instrument": "Instrument_type", - "precursor_mode": "Precursor_type", + 'name': 'Name', + 'collision_energy': 'CE', + 'instrument': 'Instrument_type', + 'precursor_mode': 'Precursor_type', } @@ -143,103 +143,103 @@ def build_metabolites(df: pd.DataFrame, model_params: dict): weight_upper_limit = 1000.0 model_setup_feature_sets = None - if "setup_features_categorical_set" in model_params.keys(): - model_setup_feature_sets = model_params["setup_features_categorical_set"] + if 'setup_features_categorical_set' in model_params.keys(): + model_setup_feature_sets = model_params['setup_features_categorical_set'] node_encoder = AtomFeatureEncoder( - feature_list=["symbol", "num_hydrogen", "ring_type"] + feature_list=['symbol', 'num_hydrogen', 'ring_type'] ) - bond_encoder = BondFeatureEncoder(feature_list=["bond_type", "ring_type"]) - if model_params["version_number"] == "0.1.0": + bond_encoder = BondFeatureEncoder(feature_list=['bond_type', 'ring_type']) + if model_params['version_number'] == '0.1.0': covariate_features = [ - "collision_energy", - "molecular_weight", - "precursor_mode", - "instrument", + 'collision_energy', + 'molecular_weight', + 'precursor_mode', + 'instrument', ] else: covariate_features = [ - "collision_energy", - "molecular_weight", - "precursor_mode", - "instrument", - "element_composition", + 'collision_energy', + 'molecular_weight', + 'precursor_mode', + 'instrument', + 'element_composition', ] setup_encoder = CovariateFeatureEncoder( feature_list=covariate_features, sets_overwrite=model_setup_feature_sets ) rt_encoder = CovariateFeatureEncoder( - feature_list=["molecular_weight", "precursor_mode", "instrument"], + feature_list=['molecular_weight', 'precursor_mode', 'instrument'], sets_overwrite=model_setup_feature_sets, ) - setup_encoder.normalize_features["collision_energy"]["max"] = ce_upper_limit - setup_encoder.normalize_features["molecular_weight"]["max"] = weight_upper_limit - rt_encoder.normalize_features["molecular_weight"]["max"] = weight_upper_limit + setup_encoder.normalize_features['collision_energy']['max'] = ce_upper_limit + setup_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit + rt_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit # Convert SMILES to Metabolites and create structure graphs and fragmentation trees - df["Metabolite"] = df["SMILES"].apply(safe_metabolite_creation) - invalid_df = df[df["Metabolite"].isna()][["Name", "SMILES"]] - df.dropna(subset=["Metabolite"], inplace=True) + df['Metabolite'] = df['SMILES'].apply(safe_metabolite_creation) + invalid_df = df[df['Metabolite'].isna()][['Name', 'SMILES']] + df.dropna(subset=['Metabolite'], inplace=True) - df["Metabolite"].apply(lambda x: x.create_molecular_structure_graph()) - df["Metabolite"].apply( + df['Metabolite'].apply(lambda x: x.create_molecular_structure_graph()) + 
df['Metabolite'].apply( lambda x: x.compute_graph_attributes(node_encoder, bond_encoder) ) # Map covariate features to dedicated format and encode - df["summary"] = df.apply( + df['summary'] = df.apply( lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1 ) df.apply( - lambda x: x["Metabolite"].add_metadata(x["summary"], setup_encoder, rt_encoder), + lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder), axis=1, ) # Fragment compounds - df["Metabolite"].apply(lambda x: x.fragment_MOL(depth=1)) + df['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1)) return df, invalid_df def prepare_output(args, df, model): - df["peaks"] = df["sim_peaks"] - df["Formula"] = df["Metabolite"].apply(lambda x: x.Formula) - df["Precursor_MZ"] = df["Metabolite"].apply( - lambda x: x.get_theoretical_precursor_mz(ion_type=x.metadata["precursor_mode"]) + df['peaks'] = df['sim_peaks'] + df['Formula'] = df['Metabolite'].apply(lambda x: x.Formula) + df['Precursor_MZ'] = df['Metabolite'].apply( + lambda x: x.get_theoretical_precursor_mz(ion_type=x.metadata['precursor_mode']) ) # Rename certain columns - if "RT_pred" in df.columns: - df["RETENTIONTIME"] = df["RT_pred"] - df["PRECURSOR_MZ"] = df["Precursor_MZ"] - df["FORMULA"] = df["Formula"] - if "CCS_pred" in df.columns: - df["CCS"] = df["CCS_pred"] + if 'RT_pred' in df.columns: + df['RETENTIONTIME'] = df['RT_pred'] + df['PRECURSOR_MZ'] = df['Precursor_MZ'] + df['FORMULA'] = df['Formula'] + if 'CCS_pred' in df.columns: + df['CCS'] = df['CCS_pred'] version = ( - model.model_params["version"] - if "version" in model.model_params - else "(pre-release version v0.0.0)" + model.model_params['version'] + if 'version' in model.model_params + else '(pre-release version v0.0.0)' ) - df["Comment"] = f'"In silico generated spectrum by {version}"' - df["COMMENT"] = df["Comment"] + df['Comment'] = f'"In silico generated spectrum by {version}"' + df['COMMENT'] = df['Comment'] # Write output file - if args.output.endswith(".msp"): - df["Collision_energy"] = df["CE"] + if args.output.endswith('.msp'): + df['Collision_energy'] = df['CE'] headers = [ - "Name", - "SMILES", - "Formula", - "Precursor_MZ", - "Precursor_type", - "Instrument_type", - "Collision_energy", + 'Name', + 'SMILES', + 'Formula', + 'Precursor_MZ', + 'Precursor_type', + 'Instrument_type', + 'Collision_energy', ] if args.rt: - headers.append("RETENTIONTIME") + headers.append('RETENTIONTIME') if args.ccs: - headers.append("CCS") - headers.append("Comment") + headers.append('CCS') + headers.append('Comment') mspWriter.write_msp( df, path=args.output, @@ -247,62 +247,62 @@ def prepare_output(args, df, model): headers=headers, annotation=args.annotation, ) - elif args.output.endswith(".mgf"): + elif args.output.endswith('.mgf'): headers = [ - "TITLE", - "SMILES", - "FORMULA", - "PRECURSOR_MZ", - "PRECURSORTYPE", - "COLLISIONENERGY", - "INSTRUMENTTYPE", + 'TITLE', + 'SMILES', + 'FORMULA', + 'PRECURSOR_MZ', + 'PRECURSORTYPE', + 'COLLISIONENERGY', + 'INSTRUMENTTYPE', ] if args.rt: - headers.append("RETENTIONTIME") + headers.append('RETENTIONTIME') if args.ccs: - headers.append("CCS") - headers.append("COMMENT") + headers.append('CCS') + headers.append('COMMENT') mgfWriter.write_mgf( df, path=args.output, write_header=True, headers=headers, header_map={ - "TITLE": "Name", - "PRECURSORTYPE": "Precursor_type", - "INSTRUMENTTYPE": "Instrument_type", - "COLLISIONENERGY": "CE", + 'TITLE': 'Name', + 'PRECURSORTYPE': 'Precursor_type', + 'INSTRUMENTTYPE': 'Instrument_type', + 
'COLLISIONENERGY': 'CE', }, annotation=args.annotation, ) else: print( - f"Warning: Unknown output format {args.output}. Writing results to {args.output}.mgf instead." + f'Warning: Unknown output format {args.output}. Writing results to {args.output}.mgf instead.' ) - args.output = args.output + ".mgf" + args.output = args.output + '.mgf' headers = [ - "TITLE", - "SMILES", - "FORMULA", - "PRECURSORTYPE", - "COLLISIONENERGY", - "INSTRUMENTTYPE", + 'TITLE', + 'SMILES', + 'FORMULA', + 'PRECURSORTYPE', + 'COLLISIONENERGY', + 'INSTRUMENTTYPE', ] if args.rt: - headers.append("RETENTIONTIME") + headers.append('RETENTIONTIME') if args.ccs: - headers.append("CCS") - headers.append("COMMENT") + headers.append('CCS') + headers.append('COMMENT') mgfWriter.write_mgf( df, path=args.output, write_header=True, headers=headers, header_map={ - "TITLE": "Name", - "PRECURSORTYPE": "Precursor_type", - "INSTRUMENTTYPE": "Instrument_type", - "COLLISIONENERGY": "CE", + 'TITLE': 'Name', + 'PRECURSORTYPE': 'Precursor_type', + 'INSTRUMENTTYPE': 'Instrument_type', + 'COLLISIONENERGY': 'CE', }, annotation=args.annotation, ) @@ -311,12 +311,12 @@ def prepare_output(args, df, model): def main() -> None: args = parse_args() if args.debug: - print(f"Running fiora prediction with the following parameters: {args}\n") + print(f'Running fiora prediction with the following parameters: {args}\n') # Load model - if args.model == "default": + if args.model == 'default': with resources.as_file( - resources.files("fiora.resources.models").joinpath("fiora_OS_v1.0.0.pt") + resources.files('fiora.resources.models').joinpath('fiora_OS_v1.0.0.pt') ) as model_path: args.model = str(model_path) @@ -324,7 +324,7 @@ def main() -> None: model = FioraModel.load_from_state_dict(args.model) except Exception as exc: raise SystemExit( - f"Error: Failed loading from model from state dict. Caused by: {exc}." + f'Error: Failed loading from model from state dict. Caused by: {exc}.' ) print_model_messages(model.model_params) @@ -343,11 +343,11 @@ def main() -> None: df, invalid_df = build_metabolites(df, model.model_params) if invalid_df.shape[0] > 0: if args.debug: - print("Warning: The following input SMILES could not be read or formatted:") + print('Warning: The following input SMILES could not be read or formatted:') print(invalid_df) else: print( - "Warning: Some SMILES could not be read or formatted. Run with --debug flag for more information." + 'Warning: Some SMILES could not be read or formatted. Run with --debug flag for more information.' ) # Simulate compound fragmentation @@ -355,8 +355,8 @@ def main() -> None: # Prepare Output prepare_output(args, df, model) - print(f"Finished prediction. Exported MS/MS spectra to {args.output}.") + print(f'Finished prediction. 
Exported MS/MS spectra to {args.output}.') -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/fiora/cli/train.py b/fiora/cli/train.py index 82573cc..6e7edd9 100644 --- a/fiora/cli/train.py +++ b/fiora/cli/train.py @@ -16,8 +16,8 @@ from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder -from fiora.GNN.FioraModel import FioraModel from fiora.GNN.fabric_training import seed_everything, train_fabric_loop +from fiora.GNN.FioraModel import FioraModel from fiora.GNN.Losses import ( GraphwiseKLLoss, GraphwiseKLLossMetric, @@ -27,200 +27,200 @@ WeightedMSEMetric, ) from fiora.IO.LibraryLoader import LibraryLoader +from fiora.MOL.constants import DEFAULT_MODES, DEFAULT_PPM from fiora.MOL.Metabolite import Metabolite from fiora.MOL.MetaboliteIndex import MetaboliteIndex -from fiora.MOL.constants import DEFAULT_MODES, DEFAULT_PPM -RDLogger.DisableLog("rdApp.*") -warnings.filterwarnings("ignore", category=SyntaxWarning) +RDLogger.DisableLog('rdApp.*') +warnings.filterwarnings('ignore', category=SyntaxWarning) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - prog="fiora-train", - description="Train a FIORA model from a preprocessed library CSV.", + prog='fiora-train', + description='Train a FIORA model from a preprocessed library CSV.', ) parser.add_argument( - "-i", - "--input", + '-i', + '--input', required=True, - help="Path to preprocessed CSV containing spectra, metadata, and SMILES.", + help='Path to preprocessed CSV containing spectra, metadata, and SMILES.', ) parser.add_argument( - "-o", - "--output", - default="checkpoint_fiora.best.pt", - help="Output path for best checkpoint (.pt).", + '-o', + '--output', + default='checkpoint_fiora.best.pt', + help='Output path for best checkpoint (.pt).', ) parser.add_argument( - "--model-params", - help="Optional path to a JSON file with base model parameters.", + '--model-params', + help='Optional path to a JSON file with base model parameters.', default=None, ) parser.add_argument( - "--resume", - help="Optional path to a checkpoint to resume from (.pt).", + '--resume', + help='Optional path to a checkpoint to resume from (.pt).', default=None, ) parser.add_argument( - "--device", - default="auto", - help="Device to run on (e.g. cpu, cuda:0). Default: auto.", - ) - parser.add_argument("--epochs", type=int, default=300) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--learning-rate", type=float, default=2e-4) - parser.add_argument("--weight-decay", type=float, default=1e-5) + '--device', + default='auto', + help='Device to run on (e.g. cpu, cuda:0). 
Default: auto.', + ) + parser.add_argument('--epochs', type=int, default=300) + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--learning-rate', type=float, default=2e-4) + parser.add_argument('--weight-decay', type=float, default=1e-5) parser.add_argument( - "--hidden-dimension", + '--hidden-dimension', type=int, default=None, - help="Override model hidden dimension (default from model params).", + help='Override model hidden dimension (default from model params).', ) parser.add_argument( - "--embedding-dimension", + '--embedding-dimension', type=int, default=None, - help="Override embedding dimension (default from model params).", + help='Override embedding dimension (default from model params).', ) parser.add_argument( - "--dense-dim", + '--dense-dim', type=int, default=None, - help="Override dense layer hidden dimension (None keeps current setting).", + help='Override dense layer hidden dimension (None keeps current setting).', ) parser.add_argument( - "--residual-connections", + '--residual-connections', action=argparse.BooleanOptionalAction, default=None, - help="Override residual connections setting.", + help='Override residual connections setting.', ) parser.add_argument( - "--layer-stacking", + '--layer-stacking', action=argparse.BooleanOptionalAction, default=None, - help="Override layer stacking setting.", + help='Override layer stacking setting.', ) parser.add_argument( - "--loss", - choices=["graphwise_kl", "weighted_mse", "weighted_mae", "mse"], - default="graphwise_kl", + '--loss', + choices=['graphwise_kl', 'weighted_mse', 'weighted_mae', 'mse'], + default='graphwise_kl', ) parser.add_argument( - "--precursor-loss-weight", + '--precursor-loss-weight', type=float, default=1.0, - help="Multiplier for precursor positions in fragment loss (1.0 keeps original weighting).", + help='Multiplier for precursor positions in fragment loss (1.0 keeps original weighting).', ) parser.add_argument( - "--y-label", - default="compiled_probsALL", - help="Label to use as training target.", + '--y-label', + default='compiled_probsALL', + help='Label to use as training target.', ) parser.add_argument( - "--with-rt", + '--with-rt', action=argparse.BooleanOptionalAction, default=False, - help="Train RT head if available.", + help='Train RT head if available.', ) parser.add_argument( - "--with-ccs", + '--with-ccs', action=argparse.BooleanOptionalAction, default=False, - help="Train CCS head if available.", + help='Train CCS head if available.', ) - parser.add_argument("--train-val-split", type=float, default=0.8) + parser.add_argument('--train-val-split', type=float, default=0.8) parser.add_argument( - "--split-by-group", + '--split-by-group', action=argparse.BooleanOptionalAction, default=True, - help="Split train/val by group_id (prevents leakage).", + help='Split train/val by group_id (prevents leakage).', ) - parser.add_argument("--group-id-col", default="group_id") - parser.add_argument("--datasplit-col", default="datasplit") - parser.add_argument("--train-label", default="training") - parser.add_argument("--val-label", default="validation") - parser.add_argument("--min-peak-matches", type=int, default=2) + parser.add_argument('--group-id-col', default='group_id') + parser.add_argument('--datasplit-col', default='datasplit') + parser.add_argument('--train-label', default='training') + parser.add_argument('--val-label', default='validation') + parser.add_argument('--min-peak-matches', type=int, default=2) parser.add_argument( - "--ppm", + '--ppm', type=float, 
default=None, - help="Default ppm tolerance if column missing.", - ) - parser.add_argument("--ppm-col", default="ppm_peak_tolerance") - parser.add_argument("--summary-col", default="summary") - parser.add_argument("--peaks-col", default="peaks") - parser.add_argument("--smiles-col", default="SMILES") - parser.add_argument("--loss-weight-col", default="loss_weight") - parser.add_argument("--max-rows", type=int, default=None) - parser.add_argument("--fragmentation-depth", type=int, default=1) + help='Default ppm tolerance if column missing.', + ) + parser.add_argument('--ppm-col', default='ppm_peak_tolerance') + parser.add_argument('--summary-col', default='summary') + parser.add_argument('--peaks-col', default='peaks') + parser.add_argument('--smiles-col', default='SMILES') + parser.add_argument('--loss-weight-col', default='loss_weight') + parser.add_argument('--max-rows', type=int, default=None) + parser.add_argument('--fragmentation-depth', type=int, default=1) parser.add_argument( - "--use-frag-index", + '--use-frag-index', action=argparse.BooleanOptionalAction, default=True, - help="Use MetaboliteIndex to cache fragmentation trees.", + help='Use MetaboliteIndex to cache fragmentation trees.', ) parser.add_argument( - "--graph-mismatch-policy", - choices=["recompute", "ignore"], - default="recompute", + '--graph-mismatch-policy', + choices=['recompute', 'ignore'], + default='recompute', ) parser.add_argument( - "--precursor-modes", + '--precursor-modes', default=None, - help="Comma-separated precursor modes to encode.", + help='Comma-separated precursor modes to encode.', ) parser.add_argument( - "--instruments", + '--instruments', default=None, - help="Comma-separated instrument types to encode.", + help='Comma-separated instrument types to encode.', ) - parser.add_argument("--ce-upper-limit", type=float, default=100.0) - parser.add_argument("--weight-upper-limit", type=float, default=1000.0) - parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--num-workers", type=int, default=0) + parser.add_argument('--ce-upper-limit', type=float, default=100.0) + parser.add_argument('--weight-upper-limit', type=float, default=1000.0) + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--num-workers', type=int, default=0) parser.add_argument( - "--pin-memory", + '--pin-memory', action=argparse.BooleanOptionalAction, default=None, - help="Pin host memory for DataLoader (auto: enabled for CUDA).", + help='Pin host memory for DataLoader (auto: enabled for CUDA).', ) - parser.add_argument("--val-every", type=int, default=1) + parser.add_argument('--val-every', type=int, default=1) parser.add_argument( - "--use-validation-mask", + '--use-validation-mask', action=argparse.BooleanOptionalAction, default=False, - help="Use validation mask during validation.", + help='Use validation mask during validation.', ) - parser.add_argument("--validation-mask-name", default="validation_mask") + parser.add_argument('--validation-mask-name', default='validation_mask') parser.add_argument( - "--scheduler", - choices=["plateau", "none"], - default="plateau", + '--scheduler', + choices=['plateau', 'none'], + default='plateau', ) - parser.add_argument("--scheduler-patience", type=int, default=8) - parser.add_argument("--scheduler-factor", type=float, default=0.5) + parser.add_argument('--scheduler-patience', type=int, default=8) + parser.add_argument('--scheduler-factor', type=float, default=0.5) parser.add_argument( - "--rt-metric", + '--rt-metric', 
action=argparse.BooleanOptionalAction, default=False, - help="Track RT/CCS metrics instead of fragment metrics.", + help='Track RT/CCS metrics instead of fragment metrics.', ) parser.add_argument( - "--index-col", + '--index-col', type=int, default=0, - help="CSV index column (default: 0). Use --no-index-col to disable.", + help='CSV index column (default: 0). Use --no-index-col to disable.', ) parser.add_argument( - "--no-index-col", - action="store_true", - help="Disable index_col when reading CSV.", + '--no-index-col', + action='store_true', + help='Disable index_col when reading CSV.', ) parser.add_argument( - "--history-out", + '--history-out', default=None, - help="Optional path to save training history (.json or .csv).", + help='Optional path to save training history (.json or .csv).', ) return parser.parse_args() @@ -239,9 +239,9 @@ def _parse_dict(val): except Exception: pass # Fallback for python-literal style dict strings. - norm = re.sub(r"\b(?:NaN|nan)\b", "None", text) - norm = re.sub(r"\b(?:Infinity|inf)\b", "1e309", norm) - norm = re.sub(r"\b(?:-Infinity|-inf)\b", "-1e309", norm) + norm = re.sub(r'\b(?:NaN|nan)\b', 'None', text) + norm = re.sub(r'\b(?:Infinity|inf)\b', '1e309', norm) + norm = re.sub(r'\b(?:-Infinity|-inf)\b', '-1e309', norm) try: parsed = ast.literal_eval(norm) return parsed if isinstance(parsed, dict) else None @@ -373,8 +373,8 @@ def _match_peaks_task(task): idx, metabolite, peaks, tol = task if not isinstance(peaks, dict): return idx, False - mz = peaks.get("mz") - intensity = peaks.get("intensity") + mz = peaks.get('mz') + intensity = peaks.get('intensity') if mz is None or intensity is None or len(mz) == 0: return idx, False try: @@ -385,36 +385,36 @@ def _match_peaks_task(task): def _resolve_device(device: str) -> str: - if device == "auto": - return "cuda:0" if torch.cuda.is_available() else "cpu" + if device == 'auto': + return 'cuda:0' if torch.cuda.is_available() else 'cpu' return device def _load_model_params(path: str | None) -> dict: if path is None: return {} - with open(path, "r") as fp: + with open(path, 'r') as fp: return json.load(fp) def _choose_loss(loss_name: str): - if loss_name == "graphwise_kl": - return GraphwiseKLLoss(reduction="mean"), {"kl": GraphwiseKLLossMetric} - if loss_name == "weighted_mse": - return WeightedMSELoss(), {"mse": WeightedMSEMetric} - if loss_name == "weighted_mae": - return WeightedMAELoss(), {"mae": WeightedMAEMetric} - if loss_name == "mse": + if loss_name == 'graphwise_kl': + return GraphwiseKLLoss(reduction='mean'), {'kl': GraphwiseKLLossMetric} + if loss_name == 'weighted_mse': + return WeightedMSELoss(), {'mse': WeightedMSEMetric} + if loss_name == 'weighted_mae': + return WeightedMAELoss(), {'mae': WeightedMAEMetric} + if loss_name == 'mse': return torch.nn.MSELoss(), None - raise ValueError(f"Unknown loss: {loss_name}") + raise ValueError(f'Unknown loss: {loss_name}') def _save_history(history: dict, output_path: str) -> None: - os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) - if output_path.lower().endswith(".csv"): + os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True) + if output_path.lower().endswith('.csv'): pd.DataFrame(history).to_csv(output_path, index=False) else: - with open(output_path, "w") as fp: + with open(output_path, 'w') as fp: json.dump(history, fp, indent=2) @@ -431,21 +431,21 @@ def _split_geo_data( if len(geo_data) == 0: return [], [] - if split_by_group and hasattr(geo_data[0], "group_id"): - group_ids = np.array([int(getattr(x, "group_id")) for x in 
geo_data]) + if split_by_group and hasattr(geo_data[0], 'group_id'): + group_ids = np.array([int(getattr(x, 'group_id')) for x in geo_data]) keys = np.unique(group_ids) if len(train_keys) > 0 and len(val_keys) > 0: train_set = set(int(x) for x in train_keys) val_set = set(int(x) for x in val_keys) - print("Using pre-set train/validation keys") + print('Using pre-set train/validation keys') else: tr, va = train_test_split( keys, test_size=1 - train_val_split, random_state=seed ) train_set = set(int(x) for x in tr) val_set = set(int(x) for x in va) - train_data = [x for x in geo_data if int(getattr(x, "group_id")) in train_set] - val_data = [x for x in geo_data if int(getattr(x, "group_id")) in val_set] + train_data = [x for x in geo_data if int(getattr(x, 'group_id')) in train_set] + val_data = [x for x in geo_data if int(getattr(x, 'group_id')) in val_set] return train_data, val_data train_size = int(len(geo_data) * train_val_split) @@ -461,7 +461,7 @@ def _split_geo_data( def main() -> None: args = parse_args() dev = _resolve_device(args.device) - np.seterr(invalid="ignore") + np.seterr(invalid='ignore') seed_everything(args.seed) index_col = None if args.no_index_col else args.index_col @@ -480,55 +480,55 @@ def main() -> None: # Prepare encoders overwrite_sets = {} if args.instruments: - overwrite_sets["instrument"] = [ - x.strip() for x in args.instruments.split(",") if x.strip() + overwrite_sets['instrument'] = [ + x.strip() for x in args.instruments.split(',') if x.strip() ] if args.precursor_modes: - overwrite_sets["precursor_mode"] = [ - x.strip() for x in args.precursor_modes.split(",") if x.strip() + overwrite_sets['precursor_mode'] = [ + x.strip() for x in args.precursor_modes.split(',') if x.strip() ] if not overwrite_sets: overwrite_sets = None node_encoder = AtomFeatureEncoder( - feature_list=["symbol", "num_hydrogen", "ring_type"] + feature_list=['symbol', 'num_hydrogen', 'ring_type'] ) - bond_encoder = BondFeatureEncoder(feature_list=["bond_type", "ring_type"]) + bond_encoder = BondFeatureEncoder(feature_list=['bond_type', 'ring_type']) covariate_encoder = CovariateFeatureEncoder( feature_list=[ - "collision_energy", - "molecular_weight", - "precursor_mode", - "instrument", - "element_composition", + 'collision_energy', + 'molecular_weight', + 'precursor_mode', + 'instrument', + 'element_composition', ], sets_overwrite=overwrite_sets, ) rt_encoder = CovariateFeatureEncoder( feature_list=[ - "molecular_weight", - "precursor_mode", - "instrument", - "element_composition", + 'molecular_weight', + 'precursor_mode', + 'instrument', + 'element_composition', ], sets_overwrite=overwrite_sets, ) - covariate_encoder.normalize_features["collision_energy"]["max"] = ( + covariate_encoder.normalize_features['collision_energy']['max'] = ( args.ce_upper_limit ) - covariate_encoder.normalize_features["molecular_weight"]["max"] = ( + covariate_encoder.normalize_features['molecular_weight']['max'] = ( args.weight_upper_limit ) - rt_encoder.normalize_features["molecular_weight"]["max"] = args.weight_upper_limit + rt_encoder.normalize_features['molecular_weight']['max'] = args.weight_upper_limit metadata_key_map = { - "name": ["Name", "NAME", "Title", "TITLE"], - "collision_energy": ["CE", "COLLISION_ENERGY", "CollisionEnergy"], - "instrument": ["Instrument_type", "instrument", "INSTRUMENT_TYPE"], - "precursor_mode": ["Precursor_type", "ADDUCT", "PRECURSORTYPE"], - "precursor_mz": ["PrecursorMZ", "PEPMASS", "PRECURSORMZ"], - "retention_time": ["RETENTIONTIME", "RTINSECONDS", "retention_time"], 
- "ccs": ["CCS", "ccs"], + 'name': ['Name', 'NAME', 'Title', 'TITLE'], + 'collision_energy': ['CE', 'COLLISION_ENERGY', 'CollisionEnergy'], + 'instrument': ['Instrument_type', 'instrument', 'INSTRUMENT_TYPE'], + 'precursor_mode': ['Precursor_type', 'ADDUCT', 'PRECURSORTYPE'], + 'precursor_mz': ['PrecursorMZ', 'PEPMASS', 'PRECURSORMZ'], + 'retention_time': ['RETENTIONTIME', 'RTINSECONDS', 'retention_time'], + 'ccs': ['CCS', 'ccs'], } # Build metabolites @@ -553,27 +553,27 @@ def main() -> None: _prepare_metabolite_task, metabolite_tasks, args.num_workers ) for idx, mol in _progress_iterator( - metabolite_results, total=len(df), desc="Building graphs" + metabolite_results, total=len(df), desc='Building graphs' ): if mol is None: invalid_rows.append(idx) continue - df.at[idx, "Metabolite"] = mol + df.at[idx, 'Metabolite'] = mol if invalid_rows: df = df.drop(index=invalid_rows) - print(f"Dropped {len(invalid_rows)} invalid rows.") + print(f'Dropped {len(invalid_rows)} invalid rows.') # Fragmentation trees if args.use_frag_index: mindex = MetaboliteIndex() - mindex.index_metabolites(df["Metabolite"]) + mindex.index_metabolites(df['Metabolite']) mindex.create_fragmentation_trees(depth=args.fragmentation_depth) mindex.add_fragmentation_trees_to_metabolite_list( - df["Metabolite"], graph_mismatch_policy=args.graph_mismatch_policy + df['Metabolite'], graph_mismatch_policy=args.graph_mismatch_policy ) else: - df["Metabolite"].apply(lambda x: x.fragment_MOL(depth=args.fragmentation_depth)) + df['Metabolite'].apply(lambda x: x.fragment_MOL(depth=args.fragmentation_depth)) # Match peaks to fragments ppm_default = args.ppm if args.ppm is not None else DEFAULT_PPM @@ -581,7 +581,7 @@ def main() -> None: match_tasks = ( ( idx, - row["Metabolite"], + row['Metabolite'], row.get(args.peaks_col), _resolve_tolerance(row, args.ppm_col, ppm_default), ) @@ -593,16 +593,16 @@ def main() -> None: if match_invalid: df = df.drop(index=match_invalid) - print(f"Dropped {len(match_invalid)} rows with invalid peaks.") + print(f'Dropped {len(match_invalid)} rows with invalid peaks.') - df["num_peak_matches"] = df["Metabolite"].apply( - lambda x: x.match_stats["num_peak_matches"] + df['num_peak_matches'] = df['Metabolite'].apply( + lambda x: x.match_stats['num_peak_matches'] ) if args.min_peak_matches > 0: before = len(df) - df = df[df["num_peak_matches"] >= args.min_peak_matches] + df = df[df['num_peak_matches'] >= args.min_peak_matches] print( - f"Filtered {before - len(df)} rows with < {args.min_peak_matches} peak matches." + f'Filtered {before - len(df)} rows with < {args.min_peak_matches} peak matches.' 
) # Train/val split @@ -629,81 +629,81 @@ def main() -> None: # Geometric data geo_data = [] for _, row in df_train.iterrows(): - data = row["Metabolite"].as_geometric_data() + data = row['Metabolite'].as_geometric_data() if args.group_id_col in df_train.columns: try: data.group_id = int(row[args.group_id_col]) except Exception: pass geo_data.append(data) - print(f"Prepared training/validation with {len(geo_data)} data points") + print(f'Prepared training/validation with {len(geo_data)} data points') # Model params default_params = { - "param_tag": "default", - "gnn_type": "RGCNConv", - "depth": 10, - "hidden_dimension": 300, - "residual_connections": False, - "layer_stacking": True, - "embedding_aggregation": "concat", - "embedding_dimension": 300, - "subgraph_features": True, - "pooling_func": "max", - "layer_norm": True, - "dense_layers": 2, - "dense_dim": 500, - "input_dropout": 0.25, - "latent_dropout": 0.25, - "prepare_additional_layers": False, - "rt_supported": False, - "ccs_supported": False, - "version": "x.x.x", + 'param_tag': 'default', + 'gnn_type': 'RGCNConv', + 'depth': 10, + 'hidden_dimension': 300, + 'residual_connections': False, + 'layer_stacking': True, + 'embedding_aggregation': 'concat', + 'embedding_dimension': 300, + 'subgraph_features': True, + 'pooling_func': 'max', + 'layer_norm': True, + 'dense_layers': 2, + 'dense_dim': 500, + 'input_dropout': 0.25, + 'latent_dropout': 0.25, + 'prepare_additional_layers': False, + 'rt_supported': False, + 'ccs_supported': False, + 'version': 'x.x.x', } base_params = _load_model_params(args.model_params) model_params = dict(default_params) model_params.update(base_params) model_params.update( { - "node_feature_layout": node_encoder.feature_numbers, - "edge_feature_layout": bond_encoder.feature_numbers, - "static_feature_dimension": geo_data[0]["static_edge_features"].shape[1], - "static_rt_feature_dimension": geo_data[0]["static_rt_features"].shape[1], - "output_dimension": len(DEFAULT_MODES) * 2, - "atom_features": node_encoder.feature_list, - "setup_features": covariate_encoder.feature_list, - "setup_features_categorical_set": covariate_encoder.categorical_sets, - "rt_features": rt_encoder.feature_list, - "prepare_additional_layers": args.with_rt or args.with_ccs, - "rt_supported": args.with_rt, - "ccs_supported": args.with_ccs, + 'node_feature_layout': node_encoder.feature_numbers, + 'edge_feature_layout': bond_encoder.feature_numbers, + 'static_feature_dimension': geo_data[0]['static_edge_features'].shape[1], + 'static_rt_feature_dimension': geo_data[0]['static_rt_features'].shape[1], + 'output_dimension': len(DEFAULT_MODES) * 2, + 'atom_features': node_encoder.feature_list, + 'setup_features': covariate_encoder.feature_list, + 'setup_features_categorical_set': covariate_encoder.categorical_sets, + 'rt_features': rt_encoder.feature_list, + 'prepare_additional_layers': args.with_rt or args.with_ccs, + 'rt_supported': args.with_rt, + 'ccs_supported': args.with_ccs, } ) if args.hidden_dimension is not None: - model_params["hidden_dimension"] = int(args.hidden_dimension) + model_params['hidden_dimension'] = int(args.hidden_dimension) if args.embedding_dimension is not None: - model_params["embedding_dimension"] = int(args.embedding_dimension) + model_params['embedding_dimension'] = int(args.embedding_dimension) if args.dense_dim is not None: - model_params["dense_dim"] = int(args.dense_dim) + model_params['dense_dim'] = int(args.dense_dim) if args.residual_connections is not None: - model_params["residual_connections"] = 
bool(args.residual_connections) + model_params['residual_connections'] = bool(args.residual_connections) if args.layer_stacking is not None: - model_params["layer_stacking"] = bool(args.layer_stacking) - if model_params.get("residual_connections", False): + model_params['layer_stacking'] = bool(args.layer_stacking) + if model_params.get('residual_connections', False): if ( - model_params.get("hidden_dimension") - != model_params.get("embedding_dimension") + model_params.get('hidden_dimension') + != model_params.get('embedding_dimension') and args.embedding_dimension is None ): - model_params["embedding_dimension"] = model_params["hidden_dimension"] - if args.dense_dim is None and "dense_dim" not in base_params: + model_params['embedding_dimension'] = model_params['hidden_dimension'] + if args.dense_dim is None and 'dense_dim' not in base_params: # Avoid shape-mismatch in dense residual blocks when using default params. - model_params["dense_dim"] = None + model_params['dense_dim'] = None # Initialize or resume model if args.resume: - state_path = args.resume.replace(".pt", "_state.pt") - params_path = args.resume.replace(".pt", "_params.json") + state_path = args.resume.replace('.pt', '_state.pt') + params_path = args.resume.replace('.pt', '_params.json') if os.path.exists(state_path) and os.path.exists(params_path): model = FioraModel.load_from_state_dict(args.resume) else: @@ -712,12 +712,12 @@ def main() -> None: model = FioraModel(model_params) if (args.with_rt or args.with_ccs) and not model.model_params.get( - "prepare_additional_layers", False + 'prepare_additional_layers', False ): raise RuntimeError( - "Model does not include RT/CCS heads but --with-rt/--with-ccs was set." + 'Model does not include RT/CCS heads but --with-rt/--with-ccs was set.' ) - model.model_params["training_label"] = args.y_label + model.model_params['training_label'] = args.y_label loss_fn, metric_dict = _choose_loss(args.loss) @@ -731,10 +731,10 @@ def main() -> None: val_keys=val_keys, ) has_validation = len(val_data) > 0 - print(f"Train/validation split: {len(train_data)} / {len(val_data)}") + print(f'Train/validation split: {len(train_data)} / {len(val_data)}') output_path = args.output - os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True) checkpoints, history = train_fabric_loop( model=model, train_data=train_data, @@ -764,9 +764,9 @@ def main() -> None: ) if args.history_out: _save_history(history, args.history_out) - print(f"Saved training history to {args.history_out}") - print(f"Finished training. Best checkpoint: {checkpoints['file']}") + print(f'Saved training history to {args.history_out}') + print(f'Finished training. 
Best checkpoint: {checkpoints["file"]}') -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/fiora/visualization/define_colors.py b/fiora/visualization/define_colors.py index a0e0f46..21a00b7 100644 --- a/fiora/visualization/define_colors.py +++ b/fiora/visualization/define_colors.py @@ -1,8 +1,8 @@ import matplotlib import matplotlib.colors +import numpy as np import seaborn as sns from matplotlib import pyplot as plt -import numpy as np from matplotlib.patches import PathPatch @@ -11,109 +11,109 @@ def mix_colors(c1, c2, ratio=1.0): return [((x * ratio) + y) / (ratio + 1) for (x, y) in z] -col_mistle = sns.color_palette("Set3")[0] +col_mistle = sns.color_palette('Set3')[0] col_mistle_dark = (0.5 * col_mistle[0], 0.9 * col_mistle[1], 0.9 * col_mistle[2]) col_mistle_bright = (0.5 * col_mistle[0], 0.9 * col_mistle[1], 0.9 * col_mistle[2]) -col_decoy = sns.color_palette("Set3")[1] -col_spectrast = sns.color_palette(palette="Set3")[ +col_decoy = sns.color_palette('Set3')[1] +col_spectrast = sns.color_palette(palette='Set3')[ 4 ] # (1,1,1) #sns.color_palette(palette="Set3")[4] # col_spectrast = (1,1,1) #sns.color_palette(palette="Set3")[4] -col_xtandem = sns.color_palette(palette="Set3")[3] -col_msf = sns.color_palette(palette="Set3")[3] +col_xtandem = sns.color_palette(palette='Set3')[3] +col_msf = sns.color_palette(palette='Set3')[3] col_olivegrey = (0.65, 0.8, 0.6) -col_st_line = "slateblue" +col_st_line = 'slateblue' -col_i1 = "purple" -col_i1_dot = "violet" +col_i1 = 'purple' +col_i1_dot = 'violet' -palette = sns.color_palette("colorblind") +palette = sns.color_palette('colorblind') color_palette = palette C = { - "green": palette[2], - "orange": palette[1], - "blue": palette[0], - "red": palette[3], - "yellow": palette[8], + 'green': palette[2], + 'orange': palette[1], + 'blue': palette[0], + 'red': palette[3], + 'yellow': palette[8], } -C["g"] = C["green"] -C["o"] = C["orange"] -C["b"] = C["blue"] -C["r"] = C["red"] -C["lightgreen"] = [1.25 * x for x in C["green"]] -C["darkgreen"] = [0.75 * x for x in C["green"]] -C["ivorygreen"] = mix_colors(C["green"], matplotlib.colors.to_rgb("ivory"), ratio=0.5) -C["chocolategreen"] = mix_colors( - C["green"], matplotlib.colors.to_rgb("chocolate"), ratio=1.5 +C['g'] = C['green'] +C['o'] = C['orange'] +C['b'] = C['blue'] +C['r'] = C['red'] +C['lightgreen'] = [1.25 * x for x in C['green']] +C['darkgreen'] = [0.75 * x for x in C['green']] +C['ivorygreen'] = mix_colors(C['green'], matplotlib.colors.to_rgb('ivory'), ratio=0.5) +C['chocolategreen'] = mix_colors( + C['green'], matplotlib.colors.to_rgb('chocolate'), ratio=1.5 ) PRINT_COL = { - "black": "\033[98m", - "blue": "\033[94m", - "green": "\033[92m", - "yellow": "\033[93m", - "red": "\033[91m", - "end": "\033[00m", + 'black': '\033[98m', + 'blue': '\033[94m', + 'green': '\033[92m', + 'yellow': '\033[93m', + 'red': '\033[91m', + 'end': '\033[00m', } ELEMENT_COLORS = { - "C": "#909090", # Carbon (gray) - "O": "#FF0D0D", # Oxygen (red) - "N": "#3050F8", # Nitrogen (blue) - "F": "#98E8F8", # Fluorine (light blue) - "Cl": "#00FF00", # Chlorine (green) - "Br": "#A62929", # Bromine (dark red) - "I": "#940094", # Iodine (purple) - "P": "#FF8000", # Phosphorus (orange) - "S": "#FFFF30", # Sulfur (yellow) - "Si": "#D1D1E0", # Silicon (light gray) - "ANY_RARE": "#FFFFFF", # White for any rare elements + 'C': '#909090', # Carbon (gray) + 'O': '#FF0D0D', # Oxygen (red) + 'N': '#3050F8', # Nitrogen (blue) + 'F': '#98E8F8', # Fluorine (light blue) + 'Cl': '#00FF00', # Chlorine (green) + 
'Br': '#A62929',  # Bromine (dark red)
+    'I': '#940094',  # Iodine (purple)
+    'P': '#FF8000',  # Phosphorus (orange)
+    'S': '#FFFF30',  # Sulfur (yellow)
+    'Si': '#D1D1E0',  # Silicon (light gray)
+    'ANY_RARE': '#FFFFFF',  # White for any rare elements
 }
 
 lightblue = (46, 64, 85)
-lightblue_hex = "#75a3d9"
+lightblue_hex = '#75a3d9'
 lightpink = (81, 55, 52)
-lightpink_hex = "#cf8c85"
+lightpink_hex = '#cf8c85'
 newpink = (255, 64, 85)  # (light blue + Red 255)
-newpink_hex = "#ffa3d6"
+newpink_hex = '#ffa3d6'
 newnewpink = (242, 163, 214)  # (light blue + Red 255)
-newnewpink_hex = "#F2A3D6"
+newnewpink_hex = '#F2A3D6'
 wippinkbutbestsofar = (221, 140, 150)
-wippinkbutbestsofar_hex = "#DD8C96"
+wippinkbutbestsofar_hex = '#DD8C96'
 evenbetterpink = (222, 140, 171)
-evenbetterpink_hex = "#DE8CAB"
+evenbetterpink_hex = '#DE8CAB'
 maybeevenbetterpink = (222, 148, 172)
-maybeevenbetterpink_hex = "#DE94AC"
+maybeevenbetterpink_hex = '#DE94AC'
 abitsofterclassic_hex = (226, 154, 181)
-abitsofterclassic_hex = "#E29AB5"
+abitsofterclassic_hex = '#E29AB5'
 
-black_hex = "#000000"
-lightgreen_hex = "#ACF39D"
-wine_hex = "#773344"
+black_hex = '#000000'
+lightgreen_hex = '#ACF39D'
+wine_hex = '#773344'
 
 bluepink = sns.color_palette(
     [lightblue_hex, lightpink_hex, black_hex, lightgreen_hex, wine_hex], as_cmap=True
 )
 bluepink_grad = sns.diverging_palette(
-    17.7, 245.8, s=75, l=50, sep=1, n=6, center="light", as_cmap=True
+    17.7, 245.8, s=75, l=50, sep=1, n=6, center='light', as_cmap=True
 )
 bluepink_grad8 = sns.diverging_palette(
-    17.7, 245.8, s=75, l=50, sep=1, n=8, center="light", as_cmap=False
+    17.7, 245.8, s=75, l=50, sep=1, n=8, center='light', as_cmap=False
 )
 
-tri_palette = ["gray", bluepink[0], bluepink[1]]
+tri_palette = ['gray', bluepink[0], bluepink[1]]
 
 
 def magma(steps):
-    return sns.color_palette("magma_r", steps)
+    return sns.color_palette('magma_r', steps)
 
 
 #
@@ -123,19 +123,19 @@ def magma(steps):
 
 def define_figure_style(style: str, palette_steps=8):
     # Define figure styles
-    if "magma-white":
-        color_palette = sns.color_palette("magma_r", palette_steps)
+    if style == 'magma-white':
+        color_palette = sns.color_palette('magma_r', palette_steps)
         sns.set_theme(
-            style="whitegrid",
+            style='whitegrid',
             rc={
-                "axes.edgecolor": "black",
-                "ytick.left": True,
-                "xtick.bottom": True,
-                "xtick.color": "black",
-                "axes.spines.bottom": True,
-                "axes.spines.right": True,
-                "axes.spines.top": True,
-                "axes.spines.left": True,
+                'axes.edgecolor': 'black',
+                'ytick.left': True,
+                'xtick.bottom': True,
+                'xtick.color': 'black',
+                'axes.spines.bottom': True,
+                'axes.spines.right': True,
+                'axes.spines.top': True,
+                'axes.spines.left': True,
             },
         )
         return color_palette
@@ -143,32 +143,32 @@ def define_figure_style(style: str, palette_steps=8):
 
 def set_theme():
     sns.set_theme(
-        style="darkgrid",
+        style='darkgrid',
         rc={
-            "axes.edgecolor": "black",
-            "ytick.left": True,
-            "xtick.bottom": True,
-            "xtick.color": "black",
-            "axes.spines.bottom": True,
-            "axes.spines.right": False,
-            "axes.spines.top": False,
-            "axes.spines.left": True,
+            'axes.edgecolor': 'black',
+            'ytick.left': True,
+            'xtick.bottom': True,
+            'xtick.color': 'black',
+            'axes.spines.bottom': True,
+            'axes.spines.right': False,
+            'axes.spines.top': False,
+            'axes.spines.left': True,
         },
     )
 
 
 def set_light_theme():
     sns.set_theme(
-        style="whitegrid",
+        style='whitegrid',
         rc={
-            "axes.edgecolor": "black",
-            "ytick.left": True,
-            "xtick.bottom": True,
-            "xtick.color": "black",
-            "axes.spines.bottom": True,
-            "axes.spines.right": True,
-            "axes.spines.top": True,
-            
"axes.spines.left": True, + 'axes.edgecolor': 'black', + 'ytick.left': True, + 'xtick.bottom': True, + 'xtick.color': 'black', + 'axes.spines.bottom': True, + 'axes.spines.right': True, + 'axes.spines.top': True, + 'axes.spines.left': True, }, ) @@ -179,16 +179,16 @@ def reset_matplotlib(): def set_all_font_sizes(size): zs = [ - "font.size", - "axes.labelsize", - "axes.titlesize", - "legend.fontsize", - "xtick.labelsize", - "xtick.major.size", - "xtick.minor.size", - "ytick.labelsize", - "ytick.major.size", - "ytick.minor.size", + 'font.size', + 'axes.labelsize', + 'axes.titlesize', + 'legend.fontsize', + 'xtick.labelsize', + 'xtick.major.size', + 'xtick.minor.size', + 'ytick.labelsize', + 'ytick.major.size', + 'ytick.minor.size', ] for z in zs: diff --git a/fiora/visualization/inspect_mgf_file.py b/fiora/visualization/inspect_mgf_file.py index 26ca4c9..45d7a0c 100644 --- a/fiora/visualization/inspect_mgf_file.py +++ b/fiora/visualization/inspect_mgf_file.py @@ -1,17 +1,18 @@ -import mgfReader import argparse + import matplotlib.pyplot as plt -import seaborn as sns +import mgfReader import pandas as pd +import seaborn as sns parser = argparse.ArgumentParser() parser.add_argument( - "-v", "--verbose", help="increase output verbosity", action="store_true" + '-v', '--verbose', help='increase output verbosity', action='store_true' ) parser.add_argument( - "-i", - "--infile", - help="path/file.mgf to the search file, which will be inspected", + '-i', + '--infile', + help='path/file.mgf to the search file, which will be inspected', type=str, required=True, ) @@ -20,36 +21,36 @@ # color_palette = "Set3" -color_palette = sns.color_palette("magma_r", 15) +color_palette = sns.color_palette('magma_r', 15) df = mgfReader.read(args.infile) -print("Total number of spectra: %s" % df.shape[0]) +print(f'Total number of spectra: {df.shape[0]}') # CHARGE -df["CHARGE"] = df["CHARGE"].apply(str) -df["CHARGE"] = pd.Categorical(df["CHARGE"], sorted(df.CHARGE.unique())) +df['CHARGE'] = df['CHARGE'].apply(str) +df['CHARGE'] = pd.Categorical(df['CHARGE'], sorted(df.CHARGE.unique())) ax = sns.countplot( - data=df, x="CHARGE", palette=color_palette, edgecolor="black" + data=df, x='CHARGE', palette=color_palette, edgecolor='black' ) # order=df['CHARGE'].value_counts().iloc[:10].index) -plt.title("MS/MS charge distribution") +plt.title('MS/MS charge distribution') plt.show() # Precursor m/z -df["precursor_mz"] = df["PEPMASS"].apply(lambda x: float(x.split(" ")[0])) +df['precursor_mz'] = df['PEPMASS'].apply(lambda x: float(x.split(' ')[0])) -sns.boxplot(data=df, y="precursor_mz", x="CHARGE", palette=color_palette) -plt.title("MS/MS precursor mz range over charge") +sns.boxplot(data=df, y='precursor_mz', x='CHARGE', palette=color_palette) +plt.title('MS/MS precursor mz range over charge') plt.show() # Num of peaks -df["num_peaks"] = df["peaks"].apply(lambda p: len(p["mz"])) +df['num_peaks'] = df['peaks'].apply(lambda p: len(p['mz'])) -sns.boxplot(data=df, y="num_peaks", x="CHARGE", palette=color_palette) -plt.title("MS/MS number of peaks per spectrum over charge") +sns.boxplot(data=df, y='num_peaks', x='CHARGE', palette=color_palette) +plt.title('MS/MS number of peaks per spectrum over charge') plt.show() diff --git a/fiora/visualization/plot_spectrum.py b/fiora/visualization/plot_spectrum.py index 424f254..40eae92 100644 --- a/fiora/visualization/plot_spectrum.py +++ b/fiora/visualization/plot_spectrum.py @@ -1,34 +1,37 @@ import argparse + +import matplotlib.pyplot as plt import numpy as np +import spectrum_visualizer 
as sv from define_colors import * +from pyteomics import pylab_aux as pa +from pyteomics import usi + import fiora.IO.mgfReader as mgfReader import fiora.IO.mspReader as mspReader -import spectrum_visualizer as sv -from pyteomics import pylab_aux as pa, usi -import matplotlib.pyplot as plt parser = argparse.ArgumentParser() parser.add_argument( - "-f", - "--file1", - help="file where spectrum is contained (.mgf or .msp)", + '-f', + '--file1', + help='file where spectrum is contained (.mgf or .msp)', type=str, - default="/home/ynowatzk/data/9MM/mgf/9MM_FASP.mgf", + default='/home/ynowatzk/data/9MM/mgf/9MM_FASP.mgf', ) parser.add_argument( - "-n", "--name1", help="exact name of spectrum", type=str, required=True + '-n', '--name1', help='exact name of spectrum', type=str, required=True ) parser.add_argument( - "-f2", "--file2", help="file where lower spectrum is found", type=str + '-f2', '--file2', help='file where lower spectrum is found', type=str ) -parser.add_argument("-n2", "--name2", help="exact name of lower spectrum", type=str) -parser.add_argument("-o", "--out", help="output file", type=str) +parser.add_argument('-n2', '--name2', help='exact name of lower spectrum', type=str) +parser.add_argument('-o', '--out', help='output file', type=str) # parser.add_argument("-a", "--annotate", help="perform spectrum annotation", action="store_true", default=False) # parser.add_argument("-p", "--peptide", help="peptide", type=str, default="None") # parser.add_argument("-c", "--charge", help="charge", type=int, default=0) -parser.add_argument("--fontsize", help="font size of the text", type=int) +parser.add_argument('--fontsize', help='font size of the text', type=int) args = parser.parse_args() @@ -44,12 +47,12 @@ def read_spectrum_from_file(file, name): - if file.endswith(".mgf"): + if file.endswith('.mgf'): return mgfReader.get_spectrum_by_name(file, name) - elif file.endswith(".msp"): + elif file.endswith('.msp'): return mspReader.get_spectrum_by_name(file, name) else: - print("UNKNOWN FILE EXTENSION:\n", file) + print('UNKNOWN FILE EXTENSION:\n', file) exit(1) @@ -67,7 +70,7 @@ def read_spectrum_from_file(file, name): sv.plot_spectrum( s1, s2, - title=args.name1 + " matched by " + args.name2.split("/")[0], + title=args.name1 + ' matched by ' + args.name2.split('/')[0], out=args.out, ) # ,annotate=args.annotate, peptide=args.peptide, charge=args.charge, font_size=args.fontsize) else: diff --git a/fiora/visualization/spectrum_visualizer.py b/fiora/visualization/spectrum_visualizer.py index 396f789..9cdb812 100644 --- a/fiora/visualization/spectrum_visualizer.py +++ b/fiora/visualization/spectrum_visualizer.py @@ -1,12 +1,15 @@ +from typing import Dict, List + import matplotlib.pyplot as plt -import spectrum_utils.plot as sup -import spectrum_utils.spectrum as sus -import seaborn as sns -from pyteomics import pylab_aux as pa, usi import pandas as pd +import seaborn as sns import spectrum_utils.fragment_annotation as fa +import spectrum_utils.plot as sup +import spectrum_utils.spectrum as sus +from pyteomics import pylab_aux as pa +from pyteomics import usi + from fiora.visualization.define_colors import * -from typing import Dict, List # From spectrum utils issue https://github.com/bittremieux/spectrum_utils/issues/56 @@ -16,7 +19,7 @@ def get_theoretical_fragments( ): fragments_masses = [] for mod in proteoform.modifications: - fragment = fa.FragmentAnnotation(ion_type="w", charge=1) + fragment = fa.FragmentAnnotation(ion_type='w', charge=1) mass = mod.source[0].mass 
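+        # Workaround from the spectrum_utils issue referenced above: every
+        # modification in the proforma string is reported as a singly charged
+        # custom 'w' ion whose annotated m/z is the modification mass itself.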
fragments_masses.append((fragment, mass)) return fragments_masses @@ -25,9 +28,9 @@ def get_theoretical_fragments( def set_custom_annotation(): # Use the custom function to annotate the fragments fa.get_theoretical_fragments = get_theoretical_fragments - fa._supported_ions += "w" + fa._supported_ions += 'w' # Set peak color for custom ion - sup.colors["w"] = lightblue_hex + sup.colors['w'] = lightblue_hex def set_default_peak_color(color): @@ -42,10 +45,10 @@ def annotate_and_plot( # Instantiate Spectrum and annotate with proforma string format (e.g. X[+9.99] ) spectrum = sus.MsmsSpectrum( - "None", 0, 1, spectrum["peaks"]["mz"], spectrum["peaks"]["intensity"] + 'None', 0, 1, spectrum['peaks']['mz'], spectrum['peaks']['intensity'] ) - x_string = "".join([f"X[+{mz}]" for mz in sorted(mz_fragments)]) - spectrum.annotate_proforma(x_string, ppm_tolerance, "ppm") + x_string = ''.join([f'X[+{mz}]' for mz in sorted(mz_fragments)]) + spectrum.annotate_proforma(x_string, ppm_tolerance, 'ppm') # Find ax and plot if not ax: @@ -75,7 +78,7 @@ def plot_spectrum( color=None, ): top_spectrum = sus.MsmsSpectrum( - "None", 0, charge, spectrum["peaks"]["mz"], spectrum["peaks"]["intensity"] + 'None', 0, charge, spectrum['peaks']['mz'], spectrum['peaks']['intensity'] ) if color: set_default_peak_color(color) @@ -84,35 +87,35 @@ def plot_spectrum( # spectrum.set_mz_range(min_mz=0, max_mz=2000) if second_spectrum is not None: bottom_spectrum = sus.MsmsSpectrum( - "None", + 'None', 0, charge, - second_spectrum["peaks"]["mz"], - second_spectrum["peaks"]["intensity"], + second_spectrum['peaks']['mz'], + second_spectrum['peaks']['intensity'], ) if highlight_matches: set_custom_annotation() - x_string = "".join( - [f"X[+{mz}]" for mz in sorted(second_spectrum["peaks"]["mz"])] + x_string = ''.join( + [f'X[+{mz}]' for mz in sorted(second_spectrum['peaks']['mz'])] ) - top_spectrum.annotate_proforma(x_string, ppm_tolerance, "ppm") - x_string = "".join([f"X[+{mz}]" for mz in sorted(spectrum["peaks"]["mz"])]) - bottom_spectrum.annotate_proforma(x_string, ppm_tolerance, "ppm") + top_spectrum.annotate_proforma(x_string, ppm_tolerance, 'ppm') + x_string = ''.join([f'X[+{mz}]' for mz in sorted(spectrum['peaks']['mz'])]) + bottom_spectrum.annotate_proforma(x_string, ppm_tolerance, 'ppm') if facet_plot: sup.facet( spec_top=top_spectrum, spec_mass_errors=top_spectrum, spec_bottom=bottom_spectrum, - mass_errors_kws={"plot_unknown": False}, + mass_errors_kws={'plot_unknown': False}, ) else: # mirror plot sup.mirror( spec_top=top_spectrum, spec_bottom=bottom_spectrum, ax=ax, - spectrum_kws={"grid": with_grid}, + spectrum_kws={'grid': with_grid}, ) if with_grid: @@ -120,15 +123,15 @@ def plot_spectrum( else: sns.despine(ax=ax) if second_spectrum is not None: - ax.spines["bottom"].set_position(("outward", 10)) + ax.spines['bottom'].set_position(('outward', 10)) # Single spectrum else: if highlight_matches and mz_matches: set_custom_annotation() - x_string = "".join([f"X[+{mz}]" for mz in sorted(mz_matches)]) - top_spectrum.annotate_proforma(x_string, ppm_tolerance, "ppm") + x_string = ''.join([f'X[+{mz}]' for mz in sorted(mz_matches)]) + top_spectrum.annotate_proforma(x_string, ppm_tolerance, 'ppm') sup.spectrum(top_spectrum, grid=with_grid, ax=ax) if with_grid: @@ -147,13 +150,13 @@ def plot_spectrum( def plot_vector_spectrum( - vec1, vec2, ax=None, title=None, y_label="probability", names=None + vec1, vec2, ax=None, title=None, y_label='probability', names=None ): v1 = pd.DataFrame( - {"range": names if names else range(len(vec1)), 
"prob": vec1, "group": "prob"} + {'range': names if names else range(len(vec1)), 'prob': vec1, 'group': 'prob'} ) v2 = pd.DataFrame( - {"range": names if names else range(len(vec2)), "prob": vec2, "group": "pred"} + {'range': names if names else range(len(vec2)), 'prob': vec2, 'group': 'pred'} ) V = pd.concat([v1, v2]) @@ -162,13 +165,13 @@ def plot_vector_spectrum( sns.barplot( ax=ax, data=V, - y="prob", - x="range", - edgecolor="black", - hue="group", + y='prob', + x='range', + edgecolor='black', + hue='group', linewidth=1.5, ) - ax.set_xlabel("") + ax.set_xlabel('') ax.set_ylabel(y_label) ax.set_title(title) return ax diff --git a/pyproject.toml b/pyproject.toml index 6e88b0e..5e589ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,13 @@ dependencies = [ "scikit-learn", ] +[project.optional-dependencies] +dev = [ + "pre-commit", + "ruff==0.15.11", + "pytest", +] + [project.scripts] fiora-predict = "fiora.cli.predict:main" fiora-train = "fiora.cli.train:main" @@ -40,7 +47,7 @@ fiora-model-info = "fiora.cli.model_info:main" [tool.setuptools] include-package-data = true -script-files = ["scripts/fiora-predict", "scripts/fiora-train", "scripts/fiora-eval", "scripts/fiora-model-info"] # like setup.py `scripts=[...]` (not console entry points) +script-files = ["scripts/fiora-predict", "scripts/fiora-train", "scripts/fiora-eval", "scripts/fiora-model-info"] [tool.setuptools.packages.find] include = ["fiora", "fiora.*"] @@ -54,3 +61,82 @@ include = ["fiora", "fiora.*"] "fiora_OS_v1.0.0_state.pt", "fiora_OS_v1.0.0_params.json", ] + +[tool.ruff] +required-version = "==0.15.11" +include = ["fiora/**/*.py", "tests/**/*.py"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "dependencies", +] +line-length = 88 +indent-width = 4 + +[tool.ruff.lint] +select = [ + "E", + "PL", + "F", + "UP", + "I", +] +ignore = [ + "F401", + "F403", + "F405", + "F841", + "E501", + "E701", + "E711", + "E731", + "E741", + "E402", + "UP006", + "UP008", + "UP015", + "UP035", + "PLR0402", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", + "PLR1704", + "PLR1711", + "PLR1714", + "PLR1722", + "PLR2004", + "PLR5501", + "PLC0207", + "PLC0415", + "PLW0603", + "PLW1641", + "PLW2901", +] +fixable = ["ALL"] +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[tool.ruff.format] +quote-style = "single" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" diff --git a/tests/test_fiora_eval.py b/tests/test_fiora_eval.py index b0af877..e2ce819 100644 --- a/tests/test_fiora_eval.py +++ b/tests/test_fiora_eval.py @@ -1,49 +1,47 @@ +import contextlib import io import os import sys - import unittest -from unittest.mock import patch -import contextlib - -from importlib.util import spec_from_loader, module_from_spec from importlib.machinery import SourceFileLoader +from importlib.util import module_from_spec, spec_from_loader +from unittest.mock import patch spec = spec_from_loader( - "fiora-eval", - SourceFileLoader("fiora-eval", os.getcwd() + "/scripts/fiora-eval"), + 'fiora-eval', + SourceFileLoader('fiora-eval', os.getcwd() + '/scripts/fiora-eval'), ) fiora_eval = module_from_spec(spec) spec.loader.exec_module(fiora_eval) -sys.modules["fiora_eval"] = fiora_eval +sys.modules['fiora_eval'] = fiora_eval class TestFioraEval(unittest.TestCase): def 
test_missing_args(self): f = io.StringIO() - with patch("sys.argv", ["main"]): + with patch('sys.argv', ['main']): with self.assertRaises(SystemExit) as cm, contextlib.redirect_stderr(f): fiora_eval.main() self.assertEqual(cm.exception.code, 2) - self.assertTrue(f.getvalue().startswith("usage:")) + self.assertTrue(f.getvalue().startswith('usage:')) def test_help(self): f = io.StringIO() - with patch("sys.argv", ["main", "-h"]): + with patch('sys.argv', ['main', '-h']): with self.assertRaises(SystemExit) as cm, contextlib.redirect_stdout(f): fiora_eval.main() self.assertEqual(cm.exception.code, 0) - self.assertTrue(f.getvalue().startswith("usage:")) - self.assertTrue("--model MODEL" in f.getvalue()) - self.assertTrue("--splits SPLITS" in f.getvalue()) + self.assertTrue(f.getvalue().startswith('usage:')) + self.assertTrue('--model MODEL' in f.getvalue()) + self.assertTrue('--splits SPLITS' in f.getvalue()) -if __name__ == "__main__": +if __name__ == '__main__': suite = unittest.TestSuite() suite.addTests( [ - TestFioraEval("test_missing_args"), - TestFioraEval("test_help"), + TestFioraEval('test_missing_args'), + TestFioraEval('test_help'), ] ) runner = unittest.TextTestRunner() diff --git a/tests/test_fiora_model_info.py b/tests/test_fiora_model_info.py index 2387758..1f7972e 100644 --- a/tests/test_fiora_model_info.py +++ b/tests/test_fiora_model_info.py @@ -1,49 +1,47 @@ +import contextlib import io import os import sys - import unittest -from unittest.mock import patch -import contextlib - -from importlib.util import spec_from_loader, module_from_spec from importlib.machinery import SourceFileLoader +from importlib.util import module_from_spec, spec_from_loader +from unittest.mock import patch spec = spec_from_loader( - "fiora-model-info", - SourceFileLoader("fiora-model-info", os.getcwd() + "/scripts/fiora-model-info"), + 'fiora-model-info', + SourceFileLoader('fiora-model-info', os.getcwd() + '/scripts/fiora-model-info'), ) fiora_model_info = module_from_spec(spec) spec.loader.exec_module(fiora_model_info) -sys.modules["fiora_model_info"] = fiora_model_info +sys.modules['fiora_model_info'] = fiora_model_info class TestFioraModelInfo(unittest.TestCase): def test_missing_args(self): f = io.StringIO() - with patch("sys.argv", ["main"]): + with patch('sys.argv', ['main']): with self.assertRaises(SystemExit) as cm, contextlib.redirect_stderr(f): fiora_model_info.main() self.assertEqual(cm.exception.code, 2) - self.assertTrue(f.getvalue().startswith("usage:")) + self.assertTrue(f.getvalue().startswith('usage:')) def test_help(self): f = io.StringIO() - with patch("sys.argv", ["main", "-h"]): + with patch('sys.argv', ['main', '-h']): with self.assertRaises(SystemExit) as cm, contextlib.redirect_stdout(f): fiora_model_info.main() self.assertEqual(cm.exception.code, 0) - self.assertTrue(f.getvalue().startswith("usage:")) - self.assertTrue("--model MODEL" in f.getvalue()) - self.assertTrue("--as-json" in f.getvalue()) + self.assertTrue(f.getvalue().startswith('usage:')) + self.assertTrue('--model MODEL' in f.getvalue()) + self.assertTrue('--as-json' in f.getvalue()) -if __name__ == "__main__": +if __name__ == '__main__': suite = unittest.TestSuite() suite.addTests( [ - TestFioraModelInfo("test_missing_args"), - TestFioraModelInfo("test_help"), + TestFioraModelInfo('test_missing_args'), + TestFioraModelInfo('test_help'), ] ) runner = unittest.TextTestRunner() diff --git a/tests/test_fiora_predict.py b/tests/test_fiora_predict.py index 74c4291..7b16bd9 100644 --- a/tests/test_fiora_predict.py +++ 
b/tests/test_fiora_predict.py @@ -1,33 +1,33 @@ +import contextlib import io import os import sys -import numpy as np - import unittest +from importlib.machinery import SourceFileLoader + +## Importing fiora predict (from executable) +from importlib.util import module_from_spec, spec_from_loader from unittest.mock import patch -import contextlib + +import numpy as np ## Fiora imports import fiora.IO.mgfReader as mgfReader from fiora.MS.spectral_scores import spectral_cosine -## Importing fiora predict (from executable) -from importlib.util import spec_from_loader, module_from_spec -from importlib.machinery import SourceFileLoader - spec = spec_from_loader( - "fiora-predict", - SourceFileLoader("fiora-predict", os.getcwd() + "/scripts/fiora-predict"), + 'fiora-predict', + SourceFileLoader('fiora-predict', os.getcwd() + '/scripts/fiora-predict'), ) fiora_predict = module_from_spec(spec) spec.loader.exec_module(fiora_predict) -sys.modules["fiora_predict"] = fiora_predict +sys.modules['fiora_predict'] = fiora_predict class TestFioraPredict(unittest.TestCase): @classmethod def setUpClass(cls): - cls.temp_path = "temp_spec.mgf" + cls.temp_path = 'temp_spec.mgf' @classmethod def tearDownClass(cls): @@ -36,68 +36,68 @@ def tearDownClass(cls): def test_missing_args(self): f = io.StringIO() - with patch("sys.argv", ["main"]): + with patch('sys.argv', ['main']): with self.assertRaises(SystemExit) as cm, contextlib.redirect_stderr(f): fiora_predict.main() self.assertEqual(cm.exception.code, 2) - self.assertTrue(f.getvalue().startswith("usage:")) + self.assertTrue(f.getvalue().startswith('usage:')) def test_help(self): f = io.StringIO() - with patch("sys.argv", ["main", "-h"]): + with patch('sys.argv', ['main', '-h']): with self.assertRaises(SystemExit) as cm, contextlib.redirect_stdout(f): fiora_predict.main() self.assertEqual(cm.exception.code, 0) - self.assertTrue(f.getvalue().startswith("usage:")) - self.assertTrue("-h, --help" in f.getvalue()) - self.assertTrue("show this help message and exit" in f.getvalue()) + self.assertTrue(f.getvalue().startswith('usage:')) + self.assertTrue('-h, --help' in f.getvalue()) + self.assertTrue('show this help message and exit' in f.getvalue()) def test_dummy(self): - self.assertEqual("fiora".upper(), "FIORA") + self.assertEqual('fiora'.upper(), 'FIORA') def test_model_cpu(self): f = io.StringIO() with patch( - "sys.argv", - ["main", "-i", "examples/example_input.csv", "-o", self.temp_path], + 'sys.argv', + ['main', '-i', 'examples/example_input.csv', '-o', self.temp_path], ): with contextlib.redirect_stdout(f): fiora_predict.main() - self.assertIn("Finished prediction.", f.getvalue()) + self.assertIn('Finished prediction.', f.getvalue()) self.assertTrue(os.path.exists(self.temp_path)) def test_model_output_integrity(self): - expected_output = "examples/expected_output.mgf" + expected_output = 'examples/expected_output.mgf' df_expected = mgfReader.read(expected_output, as_df=True) df_new = mgfReader.read(self.temp_path, as_df=True) columns = [ - "TITLE", - "SMILES", - "PRECURSORTYPE", - "COLLISIONENERGY", - "INSTRUMENTTYPE", + 'TITLE', + 'SMILES', + 'PRECURSORTYPE', + 'COLLISIONENERGY', + 'INSTRUMENTTYPE', ] self.assertDictEqual(df_expected[columns].to_dict(), df_new[columns].to_dict()) for i, data in df_expected.iterrows(): - peaks_expected = data["peaks"] - peaks_new = df_new.at[i, "peaks"] + peaks_expected = data['peaks'] + peaks_new = df_new.at[i, 'peaks'] cosine = spectral_cosine(peaks_expected, peaks_new, transform=np.sqrt) self.assertGreater(cosine, 0.99) 
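+# Note: test_model_output_integrity reuses the spectrum file written by
+# test_model_cpu, so the explicit suite below must keep that order.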
-if __name__ == "__main__": +if __name__ == '__main__': # unittest.main() suite = unittest.TestSuite() suite.addTests( [ - TestFioraPredict("test_dummy"), - TestFioraPredict("test_help"), - TestFioraPredict("test_missing_args"), - TestFioraPredict("test_model_cpu"), - TestFioraPredict("test_model_output_integrity"), + TestFioraPredict('test_dummy'), + TestFioraPredict('test_help'), + TestFioraPredict('test_missing_args'), + TestFioraPredict('test_model_cpu'), + TestFioraPredict('test_model_output_integrity'), ] ) diff --git a/tests/test_mol_core.py b/tests/test_mol_core.py index d697ae9..a3d3a7b 100644 --- a/tests/test_mol_core.py +++ b/tests/test_mol_core.py @@ -8,26 +8,25 @@ from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder +from fiora.MOL.collision_energy import NCE_to_eV, align_CE, nce_instruments from fiora.MOL.Metabolite import Metabolite from fiora.MOL.MetaboliteIndex import MetaboliteIndex -from fiora.MOL.collision_energy import NCE_to_eV, align_CE, nce_instruments - SAMPLE_SIZE = 100 SPECTRA_FIXTURE_PATH = ( - Path(__file__).resolve().parent / "data" / "mol_core_100_spectra.jsonl" + Path(__file__).resolve().parent / 'data' / 'mol_core_100_spectra.jsonl' ) -RDLogger.DisableLog("rdApp.*") +RDLogger.DisableLog('rdApp.*') -@pytest.fixture(scope="module") +@pytest.fixture(scope='module') def sample_spectra(): if not SPECTRA_FIXTURE_PATH.exists(): - pytest.skip(f"Missing test fixture: {SPECTRA_FIXTURE_PATH}") + pytest.skip(f'Missing test fixture: {SPECTRA_FIXTURE_PATH}') records = [] - with SPECTRA_FIXTURE_PATH.open("r", encoding="utf-8") as handle: + with SPECTRA_FIXTURE_PATH.open('r', encoding='utf-8') as handle: for line in handle: line = line.strip() if not line: @@ -35,73 +34,71 @@ def sample_spectra(): records.append(json.loads(line)) if len(records) != SAMPLE_SIZE: - pytest.skip( - f"Expected {SAMPLE_SIZE} records in fixture, found {len(records)}" - ) + pytest.skip(f'Expected {SAMPLE_SIZE} records in fixture, found {len(records)}') return records -@pytest.fixture(scope="module") +@pytest.fixture(scope='module') def prepared_metabolites(sample_spectra): node_encoder = AtomFeatureEncoder( - feature_list=["symbol", "num_hydrogen", "ring_type"] + feature_list=['symbol', 'num_hydrogen', 'ring_type'] ) - bond_encoder = BondFeatureEncoder(feature_list=["bond_type", "ring_type"]) + bond_encoder = BondFeatureEncoder(feature_list=['bond_type', 'ring_type']) setup_encoder = CovariateFeatureEncoder( feature_list=[ - "collision_energy", - "molecular_weight", - "precursor_mode", - "instrument", - "element_composition", + 'collision_energy', + 'molecular_weight', + 'precursor_mode', + 'instrument', + 'element_composition', ] ) rt_encoder = CovariateFeatureEncoder( feature_list=[ - "molecular_weight", - "precursor_mode", - "instrument", - "element_composition", + 'molecular_weight', + 'precursor_mode', + 'instrument', + 'element_composition', ] ) prepared = [] for row in sample_spectra: - metabolite = Metabolite(row["SMILES"]) + metabolite = Metabolite(row['SMILES']) metabolite.create_molecular_structure_graph() metabolite.compute_graph_attributes( node_encoder=node_encoder, bond_encoder=bond_encoder ) - metabolite.add_metadata(dict(row["summary"]), setup_encoder, rt_encoder) - prepared.append({"metabolite": metabolite, "peaks": row["peaks"]}) + metabolite.add_metadata(dict(row['summary']), setup_encoder, rt_encoder) + prepared.append({'metabolite': 
metabolite, 'peaks': row['peaks']}) return prepared -@pytest.fixture(scope="module") +@pytest.fixture(scope='module') def indexed_metabolites(prepared_metabolites): - metabolites = [entry["metabolite"] for entry in prepared_metabolites] + metabolites = [entry['metabolite'] for entry in prepared_metabolites] index = MetaboliteIndex() index.index_metabolites(metabolites) index.create_fragmentation_trees(depth=1) mismatches = index.add_fragmentation_trees_to_metabolite_list(metabolites) - return {"index": index, "metabolites": metabolites, "mismatches": mismatches} + return {'index': index, 'metabolites': metabolites, 'mismatches': mismatches} def test_load_100_spectra_fixture(sample_spectra): assert len(sample_spectra) == SAMPLE_SIZE first = sample_spectra[0] - assert {"SMILES", "peaks", "summary", "group_id"}.issubset(first.keys()) - assert len(first["peaks"]["mz"]) == len(first["peaks"]["intensity"]) - assert len(first["peaks"]["mz"]) > 0 - assert "collision_energy" in first["summary"] - assert "precursor_mode" in first["summary"] + assert {'SMILES', 'peaks', 'summary', 'group_id'}.issubset(first.keys()) + assert len(first['peaks']['mz']) == len(first['peaks']['intensity']) + assert len(first['peaks']['mz']) > 0 + assert 'collision_energy' in first['summary'] + assert 'precursor_mode' in first['summary'] def test_metabolite_graph_building(prepared_metabolites): assert len(prepared_metabolites) == SAMPLE_SIZE for entry in prepared_metabolites: - metabolite = entry["metabolite"] + metabolite = entry['metabolite'] assert metabolite.edges.shape[1] == 2 assert len(metabolite.edges_as_tuples) == metabolite.edges.shape[0] assert metabolite.node_features.shape[0] == metabolite.Graph.number_of_nodes() @@ -113,9 +110,9 @@ def test_metabolite_graph_building(prepared_metabolites): def test_metabolite_index_and_fragmentation_trees(indexed_metabolites): - index = indexed_metabolites["index"] - metabolites = indexed_metabolites["metabolites"] - mismatches = indexed_metabolites["mismatches"] + index = indexed_metabolites['index'] + metabolites = indexed_metabolites['metabolites'] + mismatches = indexed_metabolites['mismatches'] assert len(mismatches) == 0 assert 0 < index.get_number_of_metabolites() <= SAMPLE_SIZE @@ -131,15 +128,15 @@ def test_peak_matching_and_geometric_export(prepared_metabolites, indexed_metabo matched = 0 for entry in prepared_metabolites: - metabolite = entry["metabolite"] - peaks = entry["peaks"] - mz_list = peaks["mz"] - int_list = peaks["intensity"] + metabolite = entry['metabolite'] + peaks = entry['peaks'] + mz_list = peaks['mz'] + int_list = peaks['intensity'] metabolite.match_fragments_to_peaks(mz_list, int_list) geom = metabolite.as_geometric_data() - assert metabolite.match_stats["num_peaks"] == len(mz_list) + assert metabolite.match_stats['num_peaks'] == len(mz_list) assert ( geom.compiled_probsALL.shape[0] == metabolite.edge_count_matrix.numel() + 2 ) @@ -153,14 +150,14 @@ def test_peak_matching_and_geometric_export(prepared_metabolites, indexed_metabo def test_edge_count_cols_helper(): - mode_map = {"[M-H]-": 0, "[M+H]+": 1} + mode_map = {'[M-H]-': 0, '[M+H]+': 1} mode_count = len(mode_map) left_forward, left_backward = Metabolite._edge_count_cols( - mode_map, mode_count, "[M+H]+", "left" + mode_map, mode_count, '[M+H]+', 'left' ) right_forward, right_backward = Metabolite._edge_count_cols( - mode_map, mode_count, "[M+H]+", "right" + mode_map, mode_count, '[M+H]+', 'right' ) assert (left_forward, left_backward) == (1, 3) @@ -169,10 +166,10 @@ def 
test_edge_count_cols_helper(): def test_collision_energy_helpers(): assert NCE_to_eV(20.0, 250.0) == pytest.approx(10.0) - assert align_CE("35eV", 200.0) == pytest.approx(35.0) - assert align_CE("2keV", 200.0) == pytest.approx(2000.0) + assert align_CE('35eV', 200.0) == pytest.approx(35.0) + assert align_CE('2keV', 200.0) == pytest.approx(2000.0) assert align_CE(20.0, 250.0, instrument=nce_instruments[0]) == pytest.approx(10.0) - assert align_CE("15% (nominal)", 300.0) == pytest.approx(NCE_to_eV(15.0, 300.0)) + assert align_CE('15% (nominal)', 300.0) == pytest.approx(NCE_to_eV(15.0, 300.0)) def test_edge_count_matrix_accumulates_repeated_matches(): @@ -191,25 +188,25 @@ def __init__(self, peak_matches): def match_peak_list(self, mz_list, int_list, tolerance=None): return self.peak_matches - mode_map = {"[M+H]+": 0} + mode_map = {'[M+H]+': 0} edge = (0, 1) peak_matches = { 100.0: { - "intensity": 10.0, - "relative_intensity": 10.0 / 15.0, - "fragments": [_StubFragment(edge=edge, break_side="left")], - "ion_modes": [("[M+H]+", 100.0)], + 'intensity': 10.0, + 'relative_intensity': 10.0 / 15.0, + 'fragments': [_StubFragment(edge=edge, break_side='left')], + 'ion_modes': [('[M+H]+', 100.0)], }, 101.0: { - "intensity": 5.0, - "relative_intensity": 5.0 / 15.0, - "fragments": [_StubFragment(edge=edge, break_side="left")], - "ion_modes": [("[M+H]+", 101.0)], + 'intensity': 5.0, + 'relative_intensity': 5.0 / 15.0, + 'fragments': [_StubFragment(edge=edge, break_side='left')], + 'ion_modes': [('[M+H]+', 101.0)], }, } - metabolite = Metabolite("CC") + metabolite = Metabolite('CC') metabolite.create_molecular_structure_graph() metabolite.compute_graph_attributes() metabolite.fragmentation_tree = _StubTree(peak_matches) @@ -221,7 +218,7 @@ def match_peak_list(self, mz_list, int_list, tolerance=None): ) forward_col, backward_col = Metabolite._edge_count_cols( - mode_map, len(mode_map), "[M+H]+", "left" + mode_map, len(mode_map), '[M+H]+', 'left' ) forward_idx = ( ((torch.tensor(edge) == metabolite.edges).sum(dim=1) == 2).nonzero().squeeze() diff --git a/tests/test_trainer_history.py b/tests/test_trainer_history.py index e7f38ee..becf5fd 100644 --- a/tests/test_trainer_history.py +++ b/tests/test_trainer_history.py @@ -21,13 +21,13 @@ def test_update_history_supports_mae_only_stats(): trainer._init_history() trainer._update_history( epoch=1, - train_stats={"mae": torch.tensor(0.5)}, - val_stats={"mae": torch.tensor(0.75)}, + train_stats={'mae': torch.tensor(0.5)}, + val_stats={'mae': torch.tensor(0.75)}, lr=1e-3, ) - assert trainer.history["train_error"] == [0.5] - assert trainer.history["val_error"] == [0.75] - assert math.isnan(trainer.history["sqrt_train_error"][0]) - assert math.isnan(trainer.history["sqrt_val_error"][0]) - assert trainer.history["lr"] == [1e-3] + assert trainer.history['train_error'] == [0.5] + assert trainer.history['val_error'] == [0.75] + assert math.isnan(trainer.history['sqrt_train_error'][0]) + assert math.isnan(trainer.history['sqrt_val_error'][0]) + assert trainer.history['lr'] == [1e-3] From 75b57da8f294610291fc63f696277b3cd60c3d19 Mon Sep 17 00:00:00 2001 From: Philipp Benner Date: Sun, 19 Apr 2026 09:34:55 +0200 Subject: [PATCH 11/15] applied ruff --- .pre-commit-config.yaml | 2 + lib_loader/casmi16_loader.ipynb | 443 ++- lib_loader/casmi22_loader.ipynb | 276 +- lib_loader/gnps_library_loader.ipynb | 1067 +++--- lib_loader/ms_dial_loader.ipynb | 352 +- lib_loader/msnlib_loader.ipynb | 508 ++- lib_loader/nist_library_loader.ipynb | 587 ++- notebooks/break_tendency.ipynb | 
1750 +++++---- notebooks/grid_search.ipynb | 465 +-- notebooks/grid_stats.ipynb | 180 +- notebooks/info_graphs.ipynb | 330 +- notebooks/live_predict.ipynb | 129 +- notebooks/sandbox.ipynb | 874 ++--- notebooks/test_model.ipynb | 3725 ++++++++++---------- notebooks/train_model.ipynb | 936 +++-- resources/data/msnlib/download_msnlib.py | 103 +- resources/data/msnlib/preprocess_msnlib.py | 356 +- 17 files changed, 6013 insertions(+), 6070 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e71d57..c52674e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,4 +4,6 @@ repos: hooks: - id: ruff args: [--fix] + files: ^(fiora|tests)/.*\.py$ - id: ruff-format + files: ^(fiora|tests)/.*\.py$ diff --git a/lib_loader/casmi16_loader.ipynb b/lib_loader/casmi16_loader.ipynb index c31ee08..e004a12 100644 --- a/lib_loader/casmi16_loader.ipynb +++ b/lib_loader/casmi16_loader.ipynb @@ -32,53 +32,45 @@ "source": [ "import sys\n", "\n", - "print(f\"Working with Python {sys.version}\")\n", + "print(f'Working with Python {sys.version}')\n", "\n", - "import numpy as np\n", - "import pandas as pd\n", + "import collections\n", "import importlib\n", + "import os\n", + "import time\n", + "\n", + "# Load Modules\n", + "from os.path import expanduser\n", "\n", "# import swifter\n", "import matplotlib.pyplot as plt\n", - "import seaborn as sns\n", - "import collections\n", - "import time\n", - "import os\n", - "from rdkit import Chem\n", - "from rdkit.Chem import AllChem\n", - "from rdkit.Chem import Draw\n", + "import numpy as np\n", + "import pandas as pd\n", "import rdkit.Chem.Descriptors as Descriptors\n", - "from rdkit.Chem import PandasTools\n", - "#!pip install spektral\n", - "\n", + "import seaborn as sns\n", "\n", + "#!pip install spektral\n", "# Deep Learning\n", "import sklearn\n", "\n", - "# import spektral\n", - "from sklearn.model_selection import train_test_split\n", - "\n", "# Keras\n", - "from sklearn.model_selection import train_test_split\n", - "\n", "# import stellargraph as sg\n", - "from rdkit import RDLogger\n", + "from rdkit import Chem, RDLogger\n", + "from rdkit.Chem import AllChem, Draw, PandasTools\n", "\n", + "# import spektral\n", + "from sklearn.model_selection import train_test_split\n", "\n", - "# Load Modules\n", - "from os.path import expanduser\n", - "\n", - "home = expanduser(\"~\")\n", - "import fiora.IO.mspReader as mspReader\n", + "home = expanduser('~')\n", "import fiora.IO.mgfReader as mgfReader\n", - "import fiora.visualization.spectrum_visualizer as sv\n", "import fiora.IO.molReader as molReader\n", + "import fiora.IO.mspReader as mspReader\n", + "import fiora.visualization.spectrum_visualizer as sv\n", "\n", - "\n", - "RDLogger.DisableLog(\"rdApp.*\")\n", + "RDLogger.DisableLog('rdApp.*')\n", "\n", "\n", - "caffeine_smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\"\n", + "caffeine_smiles = 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C'\n", "caffeine_mol = Chem.MolFromSmiles(caffeine_smiles)\n", "\n", "caffeine_mol" @@ -119,7 +111,7 @@ ], "source": [ "library_directory = (\n", - " f\"{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Challenge_negative_mgf\"\n", + " f'{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Challenge_negative_mgf'\n", ")\n", "!ls $library_directory" ] @@ -141,22 +133,22 @@ "source": [ "df = []\n", "library_directory = (\n", - " f\"{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Challenge_negative_mgf\"\n", + " f'{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Challenge_negative_mgf'\n", ")\n", "\n", "for 
file in os.listdir(library_directory):\n", - " if file.endswith(\".mgf\"):\n", - " data = mgfReader.read(os.path.join(library_directory, file), sep=\"\\t\")[0]\n", - " data[\"FILE\"] = file\n", - " data[\"Precursor_type\"] = \"[M-H]-\"\n", + " if file.endswith('.mgf'):\n", + " data = mgfReader.read(os.path.join(library_directory, file), sep='\\t')[0]\n", + " data['FILE'] = file\n", + " data['Precursor_type'] = '[M-H]-'\n", " df += [data]\n", "\n", - "library_directory = library_directory.replace(\"negative\", \"positive\")\n", + "library_directory = library_directory.replace('negative', 'positive')\n", "for file in os.listdir(library_directory):\n", - " if file.endswith(\".mgf\"):\n", - " data = mgfReader.read(os.path.join(library_directory, file), sep=\"\\t\")[0]\n", - " data[\"FILE\"] = file\n", - " data[\"Precursor_type\"] = \"[M+H]+\"\n", + " if file.endswith('.mgf'):\n", + " data = mgfReader.read(os.path.join(library_directory, file), sep='\\t')[0]\n", + " data['FILE'] = file\n", + " data['Precursor_type'] = '[M+H]+'\n", " df += [data]\n", "\n", "df = pd.DataFrame(df)" @@ -250,7 +242,7 @@ ], "source": [ "solution = pd.read_csv(\n", - " os.path.join(library_directory, \"solutions_casmi2016_cat2and3.csv\")\n", + " os.path.join(library_directory, 'solutions_casmi2016_cat2and3.csv')\n", ")\n", "# solution = solution[solution[\"ION_MODE\"] == \" POSITIVE\"]\n", "# solution.reset_index(inplace=True, drop=True)\n", @@ -276,7 +268,7 @@ ], "source": [ "# Check that challenges and solutions are aligned correctly\n", - "df.apply(lambda x: x[\"FILE\"].split(\".\")[0] == x[\"ChallengeName\"], axis=1).all()" + "df.apply(lambda x: x['FILE'].split('.')[0] == x['ChallengeName'], axis=1).all()" ] }, { @@ -299,14 +291,14 @@ ], "source": [ "library_candidates_directory = (\n", - " f\"{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Challenge_Candidates\"\n", + " f'{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Challenge_Candidates'\n", ")\n", "candidates = []\n", "for file in os.listdir(library_candidates_directory):\n", - " if file.endswith(\".csv\"):\n", - " data = pd.read_csv(os.path.join(library_candidates_directory, file), sep=\",\")\n", - " SMILES = list(data[\"SMILES\"])\n", - " d = {\"cFILE\": file, \"Candidates\": SMILES}\n", + " if file.endswith('.csv'):\n", + " data = pd.read_csv(os.path.join(library_candidates_directory, file), sep=',')\n", + " SMILES = list(data['SMILES'])\n", + " d = {'cFILE': file, 'Candidates': SMILES}\n", " candidates += [d]\n", "\n", "candidates = pd.DataFrame(candidates) # .loc[81:].reset_index(drop=True)\n", @@ -331,7 +323,7 @@ ], "source": [ "df = pd.concat([df, candidates], axis=1)\n", - "df.apply(lambda x: x[\"cFILE\"].split(\".\")[0] == x[\"ChallengeName\"], axis=1).all()" + "df.apply(lambda x: x['cFILE'].split('.')[0] == x['ChallengeName'], axis=1).all()" ] }, { @@ -371,13 +363,13 @@ "outputs": [], "source": [ "save_df = False\n", - "name = \"casmi16_challenges_combined.csv\"\n", + "name = 'casmi16_challenges_combined.csv'\n", "\n", - "library_directory = \"/\".join(library_directory.split(\"/\")[:-1])\n", + "library_directory = '/'.join(library_directory.split('/')[:-1])\n", "print(library_directory)\n", "if save_df:\n", " file = os.path.join(library_directory, name)\n", - " print(\"saving to \", file)\n", + " print('saving to ', file)\n", " df.to_csv(file)\n", "\n", " # df_nist[key_columns].to_csv(library_directory + name + \"all\" + \"_\" + date + \".csv\")" @@ -522,7 +514,7 @@ } ], "source": [ - "df_cfm[df_cfm[\"Precursor_type\"] == \"[M-H]-\"]" + 
"df_cfm[df_cfm['Precursor_type'] == '[M-H]-']" ] }, { @@ -546,21 +538,21 @@ "source": [ "## prepare output for for CFM-ID\n", "save_df = False\n", - "cfm_directory = f\"{home}/data/metabolites/cfm-id/\"\n", - "name = \"casmi16_negative_solutions_cfm.txt\"\n", - "df_cfm = df[[\"ChallengeName\", \"SMILES\", \"Precursor_type\"]]\n", + "cfm_directory = f'{home}/data/metabolites/cfm-id/'\n", + "name = 'casmi16_negative_solutions_cfm.txt'\n", + "df_cfm = df[['ChallengeName', 'SMILES', 'Precursor_type']]\n", "print(df_cfm.head())\n", "\n", "if save_df:\n", " file = os.path.join(cfm_directory, name)\n", - " df_cfm[df_cfm[\"Precursor_type\"] == \"[M-H]-\"][[\"ChallengeName\", \"SMILES\"]].to_csv(\n", - " file, index=False, header=False, sep=\" \"\n", + " df_cfm[df_cfm['Precursor_type'] == '[M-H]-'][['ChallengeName', 'SMILES']].to_csv(\n", + " file, index=False, header=False, sep=' '\n", " )\n", "\n", - " name = name.replace(\"negative\", \"positive\")\n", + " name = name.replace('negative', 'positive')\n", " file = os.path.join(cfm_directory, name)\n", - " df_cfm[df_cfm[\"Precursor_type\"] == \"[M+H]+\"][[\"ChallengeName\", \"SMILES\"]].to_csv(\n", - " file, index=False, header=False, sep=\" \"\n", + " df_cfm[df_cfm['Precursor_type'] == '[M+H]+'][['ChallengeName', 'SMILES']].to_csv(\n", + " file, index=False, header=False, sep=' '\n", " )" ] }, @@ -818,16 +810,16 @@ "from rdkit import Chem\n", "from rdkit.Chem import rdMolDescriptors\n", "\n", - "df[\"MOL\"] = df[\"SMILES\"].apply(Chem.MolFromSmiles)\n", - "df[\"formula\"] = df[\"MOL\"].apply(rdMolDescriptors.CalcMolFormula)\n", - "df[\"dataset\"] = \"CASMI16\"\n", + "df['MOL'] = df['SMILES'].apply(Chem.MolFromSmiles)\n", + "df['formula'] = df['MOL'].apply(rdMolDescriptors.CalcMolFormula)\n", + "df['dataset'] = 'CASMI16'\n", "df = df.rename(\n", " columns={\n", - " \"FILE\": \"spec\",\n", - " \"ChallengeName\": \"name\",\n", - " \"Precursor_type\": \"ionization\",\n", - " \"SMILES\": \"smiles\",\n", - " \"INCHIKEY\": \"inchikey\",\n", + " 'FILE': 'spec',\n", + " 'ChallengeName': 'name',\n", + " 'Precursor_type': 'ionization',\n", + " 'SMILES': 'smiles',\n", + " 'INCHIKEY': 'inchikey',\n", " }\n", ")" ] @@ -840,14 +832,14 @@ "source": [ "save_df = False\n", "if save_df:\n", - " output_file = f\"{home}/data/metabolites/ms-pred/casmi16_labels.tsv\"\n", + " output_file = f'{home}/data/metabolites/ms-pred/casmi16_labels.tsv'\n", " df[\n", - " [\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\", \"smiles\", \"inchikey\"]\n", - " ].to_csv(output_file, index=False, sep=\"\\t\")\n", - " output_file = f\"{home}/data/metabolites/ms-pred/casmi16_positive_labels.tsv\"\n", - " df[df[\"ionization\"] == \"[M+H]+\"][\n", - " [\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\", \"smiles\", \"inchikey\"]\n", - " ].to_csv(output_file, index=False, sep=\"\\t\")" + " ['dataset', 'spec', 'name', 'formula', 'ionization', 'smiles', 'inchikey']\n", + " ].to_csv(output_file, index=False, sep='\\t')\n", + " output_file = f'{home}/data/metabolites/ms-pred/casmi16_positive_labels.tsv'\n", + " df[df['ionization'] == '[M+H]+'][\n", + " ['dataset', 'spec', 'name', 'formula', 'ionization', 'smiles', 'inchikey']\n", + " ].to_csv(output_file, index=False, sep='\\t')" ] }, { @@ -873,36 +865,36 @@ "outputs": [], "source": [ "%%capture\n", - "from modules.MOL.Metabolite import Metabolite\n", "from modules.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", "from modules.GNN.BondFeatureEncoder import BondFeatureEncoder\n", "from modules.GNN.SetupFeatureEncoder import 
SetupFeatureEncoder\n", + "from modules.MOL.Metabolite import Metabolite\n", "\n", - "df[\"Metabolite\"] = df[\"SMILES\"].apply(Metabolite)\n", - "df[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df['Metabolite'] = df['SMILES'].apply(Metabolite)\n", + "df['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", - "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", + "node_encoder = AtomFeatureEncoder(feature_list=['symbol', 'num_hydrogen', 'ring_type'])\n", + "bond_encoder = BondFeatureEncoder(feature_list=['bond_type', 'ring_type'])\n", "setup_encoder = SetupFeatureEncoder(\n", - " feature_list=[\"collision_energy\", \"molecular_weight\"]\n", + " feature_list=['collision_energy', 'molecular_weight']\n", ")\n", - "df[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", + "df['Metabolite'].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", "\n", "\n", - "df[\"CE\"] = 35.0 # 20/35/50\n", - "df[\"Instrument_type\"] = \"HCD\"\n", - "df[\"Ionization\"] = \"ESI-MS/MS\"\n", + "df['CE'] = 35.0 # 20/35/50\n", + "df['Instrument_type'] = 'HCD'\n", + "df['Ionization'] = 'ESI-MS/MS'\n", "metadata_key_map = {\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"ionization\": \"Ionization\",\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'ionization': 'Ionization',\n", " # \"precursor_mz\": \"PrecursorMZ\"\n", "}\n", - "df[\"summary\"] = df.apply(\n", + "df['summary'] = df.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n", ")\n", "df.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder=None), axis=1\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder=None), axis=1\n", ")" ] }, @@ -921,11 +913,11 @@ } ], "source": [ - "print(\"Assigning unique metabolite identifiers.\")\n", + "print('Assigning unique metabolite identifiers.')\n", "\n", "metabolite_id_map = {}\n", "\n", - "for metabolite in df[\"Metabolite\"]:\n", + "for metabolite in df['Metabolite']:\n", " is_new = True\n", " for id, other in metabolite_id_map.items():\n", " if metabolite == other:\n", @@ -937,7 +929,7 @@ " metabolite.id = new_id\n", " metabolite_id_map[new_id] = metabolite\n", "\n", - "print(f\"Found {len(metabolite_id_map)} unique molecular structures.\")" + "print(f'Found {len(metabolite_id_map)} unique molecular structures.')" ] }, { @@ -959,14 +951,13 @@ "source": [ "from modules.MOL.mol_graph import draw_graph\n", "from modules.visualization.define_colors import *\n", - "import matplotlib.pyplot as plt\n", "\n", "EXAMPLE_ID = 0\n", "example = df.loc[EXAMPLE_ID]\n", - "m = example[\"Metabolite\"]\n", + "m = example['Metabolite']\n", "\n", "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(8, 4), gridspec_kw={\"width_ratios\": [1, 1]}, sharey=False\n", + " 1, 2, figsize=(8, 4), gridspec_kw={'width_ratios': [1, 1]}, sharey=False\n", ")\n", "set_light_theme()\n", "\n", @@ -1021,74 +1012,73 @@ } ], "source": [ + "import matplotlib.pyplot as plt\n", "from modules.MOL.mol_graph import draw_graph\n", "from modules.visualization.define_colors import *\n", - "import matplotlib.pyplot as plt\n", - "\n", "\n", "EXAMPLE_ID = 0\n", "example = df.loc[EXAMPLE_ID]\n", - "m = example[\"Metabolite\"]\n", + "m = example['Metabolite']\n", 
"\n", "fig, axs = plt.subplots(\n", - " 1, 4, figsize=(8, 4), gridspec_kw={\"width_ratios\": [1, 1, 1, 1]}, sharey=False\n", + " 1, 4, figsize=(8, 4), gridspec_kw={'width_ratios': [1, 1, 1, 1]}, sharey=False\n", ")\n", "set_light_theme()\n", "\n", "img = m.draw(ax=axs[0])\n", - "m2 = df.loc[EXAMPLE_ID + 2][\"Metabolite\"]\n", + "m2 = df.loc[EXAMPLE_ID + 2]['Metabolite']\n", "img = m2.draw(ax=axs[1])\n", - "img = df.loc[EXAMPLE_ID + 1][\"Metabolite\"].draw(ax=axs[2])\n", - "img = df.loc[EXAMPLE_ID + 3][\"Metabolite\"].draw(ax=axs[3])\n", + "img = df.loc[EXAMPLE_ID + 1]['Metabolite'].draw(ax=axs[2])\n", + "img = df.loc[EXAMPLE_ID + 3]['Metabolite'].draw(ax=axs[3])\n", "plt.show()\n", "\n", "\n", "EXAMPLE_ID = 60\n", "example = df.loc[EXAMPLE_ID]\n", - "m = example[\"Metabolite\"]\n", + "m = example['Metabolite']\n", "\n", "fig, axs = plt.subplots(\n", - " 1, 4, figsize=(8, 4), gridspec_kw={\"width_ratios\": [1, 1, 1, 1]}, sharey=False\n", + " 1, 4, figsize=(8, 4), gridspec_kw={'width_ratios': [1, 1, 1, 1]}, sharey=False\n", ")\n", "set_light_theme()\n", "\n", "img = m.draw(ax=axs[0])\n", - "m2 = df.loc[EXAMPLE_ID + 2][\"Metabolite\"]\n", + "m2 = df.loc[EXAMPLE_ID + 2]['Metabolite']\n", "img = m2.draw(ax=axs[1])\n", - "img = df.loc[EXAMPLE_ID + 1][\"Metabolite\"].draw(ax=axs[2])\n", - "img = df.loc[EXAMPLE_ID + 3][\"Metabolite\"].draw(ax=axs[3])\n", + "img = df.loc[EXAMPLE_ID + 1]['Metabolite'].draw(ax=axs[2])\n", + "img = df.loc[EXAMPLE_ID + 3]['Metabolite'].draw(ax=axs[3])\n", "plt.show()\n", "\n", "EXAMPLE_ID = 200\n", "example = df.loc[EXAMPLE_ID]\n", - "m = example[\"Metabolite\"]\n", + "m = example['Metabolite']\n", "\n", "fig, axs = plt.subplots(\n", - " 1, 4, figsize=(8, 4), gridspec_kw={\"width_ratios\": [1, 1, 1, 1]}, sharey=False\n", + " 1, 4, figsize=(8, 4), gridspec_kw={'width_ratios': [1, 1, 1, 1]}, sharey=False\n", ")\n", "set_light_theme()\n", "\n", "img = m.draw(ax=axs[0])\n", - "m2 = df.loc[EXAMPLE_ID + 2][\"Metabolite\"]\n", + "m2 = df.loc[EXAMPLE_ID + 2]['Metabolite']\n", "img = m2.draw(ax=axs[1])\n", - "img = df.loc[EXAMPLE_ID + 1][\"Metabolite\"].draw(ax=axs[2])\n", - "img = df.loc[EXAMPLE_ID + 3][\"Metabolite\"].draw(ax=axs[3])\n", + "img = df.loc[EXAMPLE_ID + 1]['Metabolite'].draw(ax=axs[2])\n", + "img = df.loc[EXAMPLE_ID + 3]['Metabolite'].draw(ax=axs[3])\n", "plt.show()\n", "\n", "EXAMPLE_ID = 120\n", "example = df.loc[EXAMPLE_ID]\n", - "m = example[\"Metabolite\"]\n", + "m = example['Metabolite']\n", "\n", "fig, axs = plt.subplots(\n", - " 1, 4, figsize=(8, 4), gridspec_kw={\"width_ratios\": [1, 1, 1, 1]}, sharey=False\n", + " 1, 4, figsize=(8, 4), gridspec_kw={'width_ratios': [1, 1, 1, 1]}, sharey=False\n", ")\n", "set_light_theme()\n", "\n", "img = m.draw(ax=axs[0])\n", - "m2 = df.loc[EXAMPLE_ID + 2][\"Metabolite\"]\n", + "m2 = df.loc[EXAMPLE_ID + 2]['Metabolite']\n", "img = m2.draw(ax=axs[1])\n", - "img = df.loc[EXAMPLE_ID + 1][\"Metabolite\"].draw(ax=axs[2])\n", - "img = df.loc[EXAMPLE_ID + 3][\"Metabolite\"].draw(ax=axs[3])\n", + "img = df.loc[EXAMPLE_ID + 1]['Metabolite'].draw(ax=axs[2])\n", + "img = df.loc[EXAMPLE_ID + 3]['Metabolite'].draw(ax=axs[3])\n", "plt.show()" ] }, @@ -1125,14 +1115,14 @@ } ], "source": [ - "color_palette = define_figure_style(\"magma_white\")\n", + "color_palette = define_figure_style('magma_white')\n", "\n", "\n", "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), sharey=False)\n", "\n", "edges_bond_types = [\n", " item\n", - " for items in list(df[\"Metabolite\"].apply(lambda x: getattr(x, \"edge_bond_names\")))\n", + " for items 
in list(df['Metabolite'].apply(lambda x: getattr(x, 'edge_bond_names')))\n", " for item in items\n", "]\n", "bond_types = {\n", @@ -1146,19 +1136,19 @@ " x=list(bond_types.keys()),\n", " y=list(bond_types.values()),\n", " palette=color_palette,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " linewidth=1.5,\n", ")\n", "_, labels, autotexts = axs[1].pie(\n", " list(bond_types.values()),\n", " labels=list(bond_types.keys()),\n", " colors=color_palette,\n", - " wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5},\n", - " autopct=\"%1.0f%%\",\n", + " wedgeprops={'edgecolor': 'black', 'linewidth': 1.5},\n", + " autopct='%1.0f%%',\n", ")\n", "\n", "for i in range(len(labels)):\n", - " if labels[i].get_text() == \"TRIPLE\":\n", + " if labels[i].get_text() == 'TRIPLE':\n", " autotexts[i].remove()\n", "plt.show()" ] @@ -1198,13 +1188,13 @@ "source": [ "from modules.visualization.define_colors import define_figure_style\n", "\n", - "color_palette = define_figure_style(\"magma_white\")\n", + "color_palette = define_figure_style('magma_white')\n", "\n", "\n", "elems = [\n", " e\n", " for mol in list(\n", - " df[\"Metabolite\"].apply(lambda x: getattr(x, \"node_elements\")).values\n", + " df['Metabolite'].apply(lambda x: getattr(x, 'node_elements')).values\n", " )\n", " for e in mol\n", "]\n", @@ -1216,25 +1206,25 @@ " x=list(elem_types.keys()),\n", " y=list(elem_types.values()),\n", " palette=color_palette,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " linewidth=1.5,\n", ")\n", "_, labels, autotexts = axs[1].pie(\n", " list(elem_types.values()),\n", " colors=color_palette,\n", - " wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5},\n", - " autopct=\"%1.0f%%\",\n", + " wedgeprops={'edgecolor': 'black', 'linewidth': 1.5},\n", + " autopct='%1.0f%%',\n", ")\n", "\n", "axs[1].legend(list(elem_types.keys()))\n", "plt.show()\n", "print(elem_types)\n", "cno = (\n", - " (elem_types[\"C\"] + elem_types[\"O\"] + elem_types[\"N\"])\n", + " (elem_types['C'] + elem_types['O'] + elem_types['N'])\n", " * 100\n", " / sum(elem_types.values())\n", ")\n", - "print(f\"With {cno:.01f}% CNO\")" + "print(f'With {cno:.01f}% CNO')" ] }, { @@ -1264,8 +1254,8 @@ "elems = [\n", " e\n", " for mol in list(\n", - " df[\"Metabolite\"]\n", - " .apply(lambda x: [a.GetTotalNumHs() for a in getattr(x, \"atoms_in_order\")])\n", + " df['Metabolite']\n", + " .apply(lambda x: [a.GetTotalNumHs() for a in getattr(x, 'atoms_in_order')])\n", " .values\n", " )\n", " for e in mol\n", @@ -1278,19 +1268,19 @@ " x=list(elem_types.keys()),\n", " y=list(elem_types.values()),\n", " palette=color_palette,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " linewidth=1.5,\n", ")\n", "_, labels, autotexts = axs[1].pie(\n", " list(elem_types.values()),\n", " colors=color_palette,\n", " labels=list(elem_types.keys()),\n", - " wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5},\n", - " autopct=\"%1.0f%%\",\n", + " wedgeprops={'edgecolor': 'black', 'linewidth': 1.5},\n", + " autopct='%1.0f%%',\n", ")\n", "\n", "axs[1].legend(list(elem_types.keys()))\n", - "plt.title(\"Number of bonded hydrogens\")\n", + "plt.title('Number of bonded hydrogens')\n", "plt.show()\n", "print(elem_types)" ] @@ -1315,21 +1305,21 @@ "fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4))\n", "\n", "d = {\n", - " \"collision_energy\": df[\"Metabolite\"].apply(lambda x: x.metadata[\"collision_energy\"])\n", + " 'collision_energy': df['Metabolite'].apply(lambda x: x.metadata['collision_energy'])\n", "}\n", "\n", "sns.histplot(\n", " 
ax=ax,\n", " data=d,\n", - " x=\"collision_energy\",\n", - " color=\"blue\",\n", + " x='collision_energy',\n", + " color='blue',\n", " fill=True,\n", " binwidth=2,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " binrange=[0, 200],\n", ") # , order=list(range(0,200)))\n", - "plt.rcParams[\"patch.force_edgecolor\"] = True\n", - "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n", + "plt.rcParams['patch.force_edgecolor'] = True\n", + "plt.rcParams['axes.edgecolor'] = 'black'\n", "plt.show()" ] }, @@ -1353,21 +1343,21 @@ "fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4))\n", "\n", "d = {\n", - " \"molecular_weight\": df[\"Metabolite\"].apply(lambda x: x.metadata[\"molecular_weight\"])\n", + " 'molecular_weight': df['Metabolite'].apply(lambda x: x.metadata['molecular_weight'])\n", "}\n", "\n", "sns.histplot(\n", " ax=ax,\n", " data=d,\n", - " x=\"molecular_weight\",\n", + " x='molecular_weight',\n", " color=color_palette[3],\n", " fill=True,\n", " binwidth=2,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " binrange=[0, 1000],\n", ") # , order=list(range(0,200)))\n", - "plt.rcParams[\"patch.force_edgecolor\"] = True\n", - "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n", + "plt.rcParams['patch.force_edgecolor'] = True\n", + "plt.rcParams['axes.edgecolor'] = 'black'\n", "plt.show()" ] }, @@ -1388,10 +1378,10 @@ "%%capture\n", "from modules.MOL.constants import PPM\n", "\n", - "df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "df.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=100 * PPM\n", " ),\n", " axis=1,\n", ")" @@ -1405,16 +1395,16 @@ "source": [ "d100 = pd.DataFrame(\n", " {\n", - " \"num_peak_matches\": df[\"Metabolite\"].apply(\n", - " lambda x: x.match_stats[\"num_peak_matches\"]\n", + " 'num_peak_matches': df['Metabolite'].apply(\n", + " lambda x: x.match_stats['num_peak_matches']\n", " ),\n", - " \"num_non_precursor_matches\": df[\"Metabolite\"].apply(\n", - " lambda x: x.match_stats[\"num_non_precursor_matches\"]\n", + " 'num_non_precursor_matches': df['Metabolite'].apply(\n", + " lambda x: x.match_stats['num_non_precursor_matches']\n", " ),\n", - " \"num_peak_match_conflicts\": df[\"Metabolite\"].apply(\n", - " lambda x: x.match_stats[\"num_peak_match_conflicts\"]\n", + " 'num_peak_match_conflicts': df['Metabolite'].apply(\n", + " lambda x: x.match_stats['num_peak_match_conflicts']\n", " ),\n", - " \"group\": \"100 PPM\",\n", + " 'group': '100 PPM',\n", " }\n", ")" ] @@ -1446,49 +1436,49 @@ "# TODO Implement conflict solver\n", "\n", "coverage_tracker = {\n", - " \"counts\": [],\n", - " \"all\": [],\n", - " \"coverage\": [],\n", - " \"fragment_only_coverage\": [],\n", + " 'counts': [],\n", + " 'all': [],\n", + " 'coverage': [],\n", + " 'fragment_only_coverage': [],\n", "}\n", "\n", "drop_index = []\n", "for i, d in df.iterrows():\n", - " M = d[\"Metabolite\"]\n", + " M = d['Metabolite']\n", "\n", - " coverage_tracker[\"counts\"] += [M.match_stats[\"counts\"]]\n", - " coverage_tracker[\"all\"] += [M.match_stats[\"ms_all_counts\"]]\n", - " coverage_tracker[\"fragment_only_coverage\"] += [M.match_stats[\"coverage_wo_prec\"]]\n", - " coverage_tracker[\"coverage\"] += [M.match_stats[\"coverage\"]]\n", + " coverage_tracker['counts'] += [M.match_stats['counts']]\n", + " 
coverage_tracker['all'] += [M.match_stats['ms_all_counts']]\n", + " coverage_tracker['fragment_only_coverage'] += [M.match_stats['coverage_wo_prec']]\n", + " coverage_tracker['coverage'] += [M.match_stats['coverage']]\n", "\n", " # if M.edge_break_prob_wo_precursor.sum() <= 0.01:\n", " # drop_index.append(i)\n", " # if M.edge_break_prob.sum() < 0.05: # TODO\n", " # drop_index.append(i)\n", "\n", - " if M.match_stats[\"coverage\"] < 0.25: # Filter if total coverage is too low\n", + " if M.match_stats['coverage'] < 0.25: # Filter if total coverage is too low\n", " drop_index.append(i)\n", " # if M.match_stats[\"coverage_wo_prec\"] < 0.1: # Filter if fragment coverage is too low (intensity wise)\n", " # drop_index.append(i)\n", "\n", "# filter low res instruments TODO update to low quality spectra\n", - "is_iontrap = df[\"Metabolite\"].apply(lambda x: x.metadata[\"instrument\"] == \"IT/ion trap\")\n", + "is_iontrap = df['Metabolite'].apply(lambda x: x.metadata['instrument'] == 'IT/ion trap')\n", "drop_index += list(df[is_iontrap].index)\n", "\n", "fig, axs = plt.subplots(1, 3, figsize=(12.8, 4.2), sharey=True)\n", "\n", "plt.ylim([-0.02, 1.02])\n", "sns.boxplot(\n", - " ax=axs[0], data=coverage_tracker, y=\"fragment_only_coverage\", color=color_palette[1]\n", + " ax=axs[0], data=coverage_tracker, y='fragment_only_coverage', color=color_palette[1]\n", ")\n", - "sns.boxplot(ax=axs[1], data=coverage_tracker, y=\"coverage\", color=color_palette[2])\n", - "sns.swarmplot(ax=axs[2], data=coverage_tracker, y=\"coverage\", color=color_palette[2])\n", - "axs[0].set_title(\"Coverage of peak intensity (fragments only)\")\n", - "axs[1].set_title(\"Coverage of peak intensity\")\n", - "axs[2].set_title(\"Coverage of peak intensity\")\n", + "sns.boxplot(ax=axs[1], data=coverage_tracker, y='coverage', color=color_palette[2])\n", + "sns.swarmplot(ax=axs[2], data=coverage_tracker, y='coverage', color=color_palette[2])\n", + "axs[0].set_title('Coverage of peak intensity (fragments only)')\n", + "axs[1].set_title('Coverage of peak intensity')\n", + "axs[2].set_title('Coverage of peak intensity')\n", "plt.show()\n", "\n", - "print(f\"Filtering would drop {len(drop_index)} out of {df.shape[0]}\")" + "print(f'Filtering would drop {len(drop_index)} out of {df.shape[0]}')" ] }, { @@ -1664,10 +1654,11 @@ ], "source": [ "import modules.IO.cfmReader as cfmReader\n", + "\n", "# time CFM-ID 4: -> 12m16,571s\n", "\n", "cf = cfmReader.read(\n", - " f\"{home}/data/metabolites/cfm-id/casmi16_positive_predictions.txt\", as_df=True\n", + " f'{home}/data/metabolites/cfm-id/casmi16_positive_predictions.txt', as_df=True\n", ")\n", "cf.head(2)" ] @@ -1708,7 +1699,7 @@ ], "source": [ "library_directory = (\n", - " f\"{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Training_negative_mgf\"\n", + " f'{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Training_negative_mgf'\n", ")\n", "!ls $library_directory" ] @@ -1722,24 +1713,24 @@ "df = []\n", "\n", "for file in os.listdir(library_directory):\n", - " if file.endswith(\".mgf\"):\n", + " if file.endswith('.mgf'):\n", " data = mgfReader.read(\n", - " os.path.join(library_directory, file), sep=\"\\t\", debug=False\n", + " os.path.join(library_directory, file), sep='\\t', debug=False\n", " )[0]\n", - " data[\"FILE\"] = file\n", - " data[\"Precursor_type\"] = \"[M-H]-\"\n", + " data['FILE'] = file\n", + " data['Precursor_type'] = '[M-H]-'\n", " df += [data]\n", "\n", - "library_directory = library_directory.replace(\"negative\", \"positive\")\n", + "library_directory = 
library_directory.replace('negative', 'positive')\n", "for file in os.listdir(library_directory):\n", - " if file.endswith(\".mgf\"):\n", - " data = mgfReader.read(os.path.join(library_directory, file), sep=\"\\t\")[0]\n", - " data[\"FILE\"] = file\n", - " data[\"Precursor_type\"] = \"[M+H]+\"\n", + " if file.endswith('.mgf'):\n", + " data = mgfReader.read(os.path.join(library_directory, file), sep='\\t')[0]\n", + " data['FILE'] = file\n", + " data['Precursor_type'] = '[M+H]+'\n", " df += [data]\n", "\n", "df = pd.DataFrame(df)\n", - "df[\"ChallengeName\"] = df[\"FILE\"].apply(lambda x: x.split(\".\")[0])" + "df['ChallengeName'] = df['FILE'].apply(lambda x: x.split('.')[0])" ] }, { @@ -1861,15 +1852,15 @@ ], "source": [ "solution = pd.read_csv(\n", - " os.path.join(library_directory, \"..\", \"CASMI2016_Cat2and3_Training.csv\")\n", + " os.path.join(library_directory, '..', 'CASMI2016_Cat2and3_Training.csv')\n", ")\n", "# solution = solution[solution[\"ION_MODE\"] == \" POSITIVE\"]\n", "# solution.reset_index(inplace=True, drop=True)\n", "# df = pd.concat([df, solution], axis=1)\n", "\n", - "df = pd.merge(df, solution, left_on=\"ChallengeName\", right_on=\"challengename\")\n", + "df = pd.merge(df, solution, left_on='ChallengeName', right_on='challengename')\n", "\n", - "print(df[\"SMILES\"].head(2))" + "print(df['SMILES'].head(2))" ] }, { @@ -1912,8 +1903,8 @@ "source": [ "import seaborn as sns\n", "\n", - "print(df.groupby(\"Precursor_type\")[\"coverage100PPM\"].median())\n", - "sns.kdeplot(data=df, x=\"coverage100PPM\", hue=\"Precursor_type\")" + "print(df.groupby('Precursor_type')['coverage100PPM'].median())\n", + "sns.kdeplot(data=df, x='coverage100PPM', hue='Precursor_type')" ] }, { @@ -1943,14 +1934,14 @@ ], "source": [ "library_candidates_directory = (\n", - " f\"{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Training_Candidates\"\n", + " f'{home}/data/metabolites/CASMI_2016/CASMI2016_Cat2and3_Training_Candidates'\n", ")\n", "candidates = []\n", "for file in os.listdir(library_candidates_directory):\n", - " if file.endswith(\".csv\"):\n", - " data = pd.read_csv(os.path.join(library_candidates_directory, file), sep=\",\")\n", - " SMILES = list(data[\"SMILES\"])\n", - " d = {\"ChallengeName\": file.split(\".\")[0], \"cFILE\": file, \"Candidates\": SMILES}\n", + " if file.endswith('.csv'):\n", + " data = pd.read_csv(os.path.join(library_candidates_directory, file), sep=',')\n", + " SMILES = list(data['SMILES'])\n", + " d = {'ChallengeName': file.split('.')[0], 'cFILE': file, 'Candidates': SMILES}\n", " candidates += [d]\n", "\n", "candidates = pd.DataFrame(candidates) # .loc[81:].reset_index(drop=True)\n", @@ -1974,8 +1965,8 @@ } ], "source": [ - "df = pd.merge(df, candidates, on=\"ChallengeName\")\n", - "df.apply(lambda x: x[\"cFILE\"].split(\".\")[0] == x[\"ChallengeName\"], axis=1).all()" + "df = pd.merge(df, candidates, on='ChallengeName')\n", + "df.apply(lambda x: x['cFILE'].split('.')[0] == x['ChallengeName'], axis=1).all()" ] }, { @@ -2006,7 +1997,7 @@ } ], "source": [ - "df[\"ChallengeName\"]" + "df['ChallengeName']" ] }, { @@ -2025,13 +2016,13 @@ ], "source": [ "save_df = False\n", - "name = \"casmi16t.csv\"\n", + "name = 'casmi16t.csv'\n", "\n", - "library_directory = \"/\".join(library_directory.split(\"/\")[:-1])\n", + "library_directory = '/'.join(library_directory.split('/')[:-1])\n", "print(library_directory)\n", "if save_df:\n", " file = os.path.join(library_directory, name)\n", - " print(\"saving to \", file)\n", + " print('saving to ', file)\n", " 
df.to_csv(file)\n", "\n", " # df_nist[key_columns].to_csv(library_directory + name + \"all\" + \"_\" + date + \".csv\")" @@ -2065,21 +2056,21 @@ "source": [ "## prepare output for CFM-ID\n", "save_df = False\n", - "cfm_directory = f\"{home}/data/metabolites/cfm-id/\"\n", - "name = \"casmi16t_negative_solutions_cfm.txt\"\n", - "df_cfm = df[[\"ChallengeName\", \"SMILES\", \"Precursor_type\"]]\n", + "cfm_directory = f'{home}/data/metabolites/cfm-id/'\n", + "name = 'casmi16t_negative_solutions_cfm.txt'\n", + "df_cfm = df[['ChallengeName', 'SMILES', 'Precursor_type']]\n", "print(df_cfm.head())\n", "\n", "if save_df:\n", " file = os.path.join(cfm_directory, name)\n", - " df_cfm[df_cfm[\"Precursor_type\"] == \"[M-H]-\"][[\"ChallengeName\", \"SMILES\"]].to_csv(\n", - " file, index=False, header=False, sep=\" \"\n", + " df_cfm[df_cfm['Precursor_type'] == '[M-H]-'][['ChallengeName', 'SMILES']].to_csv(\n", + " file, index=False, header=False, sep=' '\n", " )\n", "\n", - " name = name.replace(\"negative\", \"positive\")\n", + " name = name.replace('negative', 'positive')\n", " file = os.path.join(cfm_directory, name)\n", - " df_cfm[df_cfm[\"Precursor_type\"] == \"[M+H]+\"][[\"ChallengeName\", \"SMILES\"]].to_csv(\n", - " file, index=False, header=False, sep=\" \"\n", + " df_cfm[df_cfm['Precursor_type'] == '[M+H]+'][['ChallengeName', 'SMILES']].to_csv(\n", + " file, index=False, header=False, sep=' '\n", " )" ] }, @@ -2093,16 +2084,16 @@ "source": [ "from rdkit import Chem\n", "from rdkit.Chem import rdMolDescriptors\n", "\n", - "df[\"MOL\"] = df[\"SMILES\"].apply(Chem.MolFromSmiles)\n", - "df[\"formula\"] = df[\"MOL\"].apply(rdMolDescriptors.CalcMolFormula)\n", - "df[\"dataset\"] = \"CASMI16\"\n", + "df['MOL'] = df['SMILES'].apply(Chem.MolFromSmiles)\n", + "df['formula'] = df['MOL'].apply(rdMolDescriptors.CalcMolFormula)\n", + "df['dataset'] = 'CASMI16'\n", "df = df.rename(\n", " columns={\n", - " \"FILE\": \"spec\",\n", - " \"ChallengeName\": \"name\",\n", - " \"Precursor_type\": \"ionization\",\n", - " \"SMILES\": \"smiles\",\n", - " \"INCHIKEY\": \"inchikey\",\n", + " 'FILE': 'spec',\n", + " 'ChallengeName': 'name',\n", + " 'Precursor_type': 'ionization',\n", + " 'SMILES': 'smiles',\n", + " 'INCHIKEY': 'inchikey',\n", " }\n", ")" ] }, @@ -2115,14 +2106,14 @@ "source": [ "save_df = False\n", "if save_df:\n", - " output_file = f\"{home}/data/metabolites/ms-pred/casmi16t_labels.tsv\"\n", + " output_file = f'{home}/data/metabolites/ms-pred/casmi16t_labels.tsv'\n", " df[\n", - " [\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\", \"smiles\", \"inchikey\"]\n", - " ].to_csv(output_file, index=False, sep=\"\\t\")\n", - " output_file = f\"{home}/data/metabolites/ms-pred/casmi16t_positive_labels.tsv\"\n", - " df[df[\"ionization\"] == \"[M+H]+\"][\n", - " [\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\", \"smiles\", \"inchikey\"]\n", - " ].to_csv(output_file, index=False, sep=\"\\t\")" + " ['dataset', 'spec', 'name', 'formula', 'ionization', 'smiles', 'inchikey']\n", + " ].to_csv(output_file, index=False, sep='\\t')\n", + " output_file = f'{home}/data/metabolites/ms-pred/casmi16t_positive_labels.tsv'\n", + " df[df['ionization'] == '[M+H]+'][\n", + " ['dataset', 'spec', 'name', 'formula', 'ionization', 'smiles', 'inchikey']\n", + " ].to_csv(output_file, index=False, sep='\\t')" ] } ], diff --git a/lib_loader/casmi22_loader.ipynb b/lib_loader/casmi22_loader.ipynb index 648de06..e81969e 100644 --- a/lib_loader/casmi22_loader.ipynb +++ b/lib_loader/casmi22_loader.ipynb @@ -16,40 +16,36 @@ "source": [ 
"import sys\n", "\n", - "print(f\"Working with Python {sys.version}\")\n", + "print(f'Working with Python {sys.version}')\n", "\n", - "import numpy as np\n", - "import pandas as pd\n", + "import collections\n", "import importlib\n", + "import os\n", + "import time\n", "\n", "# import swifter\n", "import matplotlib.pyplot as plt\n", - "import seaborn as sns\n", - "import collections\n", - "import time\n", - "import os\n", - "from rdkit import Chem\n", - "from rdkit.Chem import AllChem\n", - "from rdkit.Chem import Draw\n", - "import rdkit.Chem.Descriptors as Descriptors\n", - "from rdkit.Chem import PandasTools\n", - "from rdkit import RDLogger\n", + "import numpy as np\n", + "import pandas as pd\n", "import pymzml\n", "import pyteomics\n", + "import rdkit.Chem.Descriptors as Descriptors\n", + "import seaborn as sns\n", + "from rdkit import Chem, RDLogger\n", + "from rdkit.Chem import AllChem, Draw, PandasTools\n", "\n", "#\n", "# Load Modules\n", - "sys.path.append(\"..\")\n", + "sys.path.append('..')\n", "from os.path import expanduser\n", "\n", - "home = expanduser(\"~\")\n", - "import modules.IO.mspReader as mspReader\n", + "home = expanduser('~')\n", "import modules.IO.mgfReader as mgfReader\n", - "import modules.visualization.spectrum_visualizer as sv\n", "import modules.IO.molReader as molReader\n", + "import modules.IO.mspReader as mspReader\n", + "import modules.visualization.spectrum_visualizer as sv\n", "\n", - "\n", - "RDLogger.DisableLog(\"rdApp.*\")" + "RDLogger.DisableLog('rdApp.*')" ] }, { @@ -74,23 +70,22 @@ "metadata": {}, "outputs": [], "source": [ - "from pyopenms import *\n", "from modules.MS.ms_utility import merge_spectrum, normalize_spectrum\n", - "import modules.visualization.spectrum_visualizer as sv\n", "from modules.MS.spectral_scores import spectral_cosine\n", + "from pyopenms import *\n", "\n", "\n", "def extract_compound_spectra(df, rt, precursor_mz, rt_tolerance=0.1, mz_tolerance=0.1):\n", - " df[\"RT_dif\"] = abs(df[\"RT_min\"] - rt)\n", - " r = df[\"RT_dif\"] < rt_tolerance\n", - " p = abs(df[\"precursor_mz\"] - precursor_mz) < mz_tolerance\n", + " df['RT_dif'] = abs(df['RT_min'] - rt)\n", + " r = df['RT_dif'] < rt_tolerance\n", + " p = abs(df['precursor_mz'] - precursor_mz) < mz_tolerance\n", "\n", " return df[np.logical_and(r, p)]\n", "\n", "\n", "def score_against_merged(df, ref_spec):\n", " for i, d in df.iterrows():\n", - " df.at[i, \"ref_score\"] = spectral_cosine(d[\"peaks\"], ref_spec, transform=np.sqrt)\n", + " df.at[i, 'ref_score'] = spectral_cosine(d['peaks'], ref_spec, transform=np.sqrt)\n", "\n", "\n", "def merge_df(DF):\n", @@ -100,18 +95,18 @@ " m = None\n", " for i in range(DF.shape[0]):\n", " if i == 0:\n", - " m = DF.iloc[i][\"peaks\"].copy()\n", + " m = DF.iloc[i]['peaks'].copy()\n", " else:\n", - " m = merge_spectrum(m, DF.iloc[i][\"peaks\"].copy(), merge_tolerance=0.01)\n", + " m = merge_spectrum(m, DF.iloc[i]['peaks'].copy(), merge_tolerance=0.01)\n", "\n", " M = pd.DataFrame(\n", " {\n", - " \"RT\": DF[\"RT\"].mean(),\n", - " \"RT_min\": DF[\"RT_min\"].mean(),\n", - " \"precursor_mz\": DF[\"precursor_mz\"].mean(),\n", - " \"Instrument_type\": \"HCD\",\n", - " \"NCE\": DF[\"NCE\"].mean(),\n", - " \"peaks\": [m],\n", + " 'RT': DF['RT'].mean(),\n", + " 'RT_min': DF['RT_min'].mean(),\n", + " 'precursor_mz': DF['precursor_mz'].mean(),\n", + " 'Instrument_type': 'HCD',\n", + " 'NCE': DF['NCE'].mean(),\n", + " 'peaks': [m],\n", " }\n", " )\n", " score_against_merged(DF, m)\n", @@ -133,21 +128,21 @@ "\n", " # Create dataframe with 
metadata\n", " df = exp.get_df()\n", - " df[\"RT_min\"] = df[\"RT\"] / 60.0\n", - " df[\"precursor_mz\"] = [\n", + " df['RT_min'] = df['RT'] / 60.0\n", + " df['precursor_mz'] = [\n", " spec.getAcquisitionInfo()[0].getMetaValue(\n", - " \"[Thermo Trailer Extra]Monoisotopic M/Z:\"\n", + " '[Thermo Trailer Extra]Monoisotopic M/Z:'\n", " )\n", " for spec in exp\n", " ]\n", - " df[\"ms_level\"] = [spec.getMSLevel() for spec in exp]\n", - " df[\"filter_string\"] = [spec.getMetaValue(\"filter string\") for spec in exp]\n", - " df = df[df[\"ms_level\"] == 2]\n", - " df[\"hcd\"] = df[\"filter_string\"].apply(lambda x: x.split(\"@\")[1].split(\" \")[0])\n", - " df[\"Instrument_type\"] = \"HCD\"\n", - " df[\"NCE\"] = df[\"hcd\"].apply(lambda x: x[3:]).astype(float)\n", - " df[\"peaks\"] = df.apply(\n", - " lambda x: {\"mz\": list(x[\"mzarray\"]), \"intensity\": list(x[\"intarray\"])}, axis=1\n", + " df['ms_level'] = [spec.getMSLevel() for spec in exp]\n", + " df['filter_string'] = [spec.getMetaValue('filter string') for spec in exp]\n", + " df = df[df['ms_level'] == 2]\n", + " df['hcd'] = df['filter_string'].apply(lambda x: x.split('@')[1].split(' ')[0])\n", + " df['Instrument_type'] = 'HCD'\n", + " df['NCE'] = df['hcd'].apply(lambda x: x[3:]).astype(float)\n", + " df['peaks'] = df.apply(\n", + " lambda x: {'mz': list(x['mzarray']), 'intensity': list(x['intarray'])}, axis=1\n", " )\n", "\n", " ## Extract compound spectra\n", @@ -157,15 +152,15 @@ " df_extract = extract_compound_spectra(\n", " df, rt, precursor_mz, rt_tolerance=rt_tolerance, mz_tolerance=mz_tolerance\n", " )\n", - " df_extract[\"peaks\"].apply(normalize_spectrum)\n", + " df_extract['peaks'].apply(normalize_spectrum)\n", "\n", " if verbose:\n", - " print(f\"Extracted {df_extract.shape[0]} spectra with match rt and mz\")\n", + " print(f'Extracted {df_extract.shape[0]} spectra with match rt and mz')\n", "\n", - " df_low = df_extract[df_extract[\"NCE\"] == 35.0].sort_values(\"RT_dif\", ascending=True)\n", - " df_med = df_extract[df_extract[\"NCE\"] == 45.0].sort_values(\"RT_dif\", ascending=True)\n", - " df_high = df_extract[df_extract[\"NCE\"] == 65.0].sort_values(\n", - " \"RT_dif\", ascending=True\n", + " df_low = df_extract[df_extract['NCE'] == 35.0].sort_values('RT_dif', ascending=True)\n", + " df_med = df_extract[df_extract['NCE'] == 45.0].sort_values('RT_dif', ascending=True)\n", + " df_high = df_extract[df_extract['NCE'] == 65.0].sort_values(\n", + " 'RT_dif', ascending=True\n", " )\n", "\n", " # print(df_low.shape[0],df_med.shape[0],df_high.shape[0] )\n", @@ -175,42 +170,42 @@ " if df_low.shape[0] > 0:\n", " low = merge_df(df_low)\n", " challenges = pd.concat([challenges, low])\n", - " ref_scores += list(df_low[\"ref_score\"].values)\n", + " ref_scores += list(df_low['ref_score'].values)\n", "\n", " if df_med.shape[0] > 0:\n", " med = merge_df(df_med)\n", " challenges = pd.concat([challenges, med])\n", - " ref_scores += list(df_med[\"ref_score\"].values)\n", + " ref_scores += list(df_med['ref_score'].values)\n", "\n", " if df_high.shape[0] > 0:\n", " high = merge_df(df_high)\n", " challenges = pd.concat([challenges, high])\n", - " ref_scores += list(df_high[\"ref_score\"].values)\n", + " ref_scores += list(df_high['ref_score'].values)\n", "\n", " # print(challenges)\n", " if challenges.shape[0] == 0:\n", - " print(\"Warning: Could not extract challenge!!! No matches found.\")\n", + " print('Warning: Could not extract challenge!!! 
No matches found.')\n", " else:\n", " min_ref_score = min(ref_scores)\n", " if min_ref_score < 0.9:\n", " print(\n", - " f\"Warning: Low cosine score of {min_ref_score:.2f} detected. (All ref_scores {ref_scores})\"\n", + " f'Warning: Low cosine score of {min_ref_score:.2f} detected. (All ref_scores {ref_scores})'\n", " )\n", " # raise Warning(\"Low cosine score detected between merged spectrum and at least one experimental spectrum. Tolerance values might have picked up a false RT/MZ match\")\n", "\n", " if verbose:\n", - " print(f\"\\nMerged {df_low.shape[0]} with NCE 35\")\n", - " print(df_low[\"ref_score\"])\n", - " print(f\"\\nMerged {df_med.shape[0]} with NCE 45\")\n", - " print(df_med[\"ref_score\"])\n", - " print(f\"\\nMerged {df_high.shape[0]} with NCE 65\")\n", - " print(df_high[\"ref_score\"])\n", + " print(f'\\nMerged {df_low.shape[0]} with NCE 35')\n", + " print(df_low['ref_score'])\n", + " print(f'\\nMerged {df_med.shape[0]} with NCE 45')\n", + " print(df_med['ref_score'])\n", + " print(f'\\nMerged {df_high.shape[0]} with NCE 65')\n", + " print(df_high['ref_score'])\n", " sv.plot_spectrum(df_low.iloc[0], challenges.iloc[0])\n", " sv.plot_spectrum(df_med.iloc[0], challenges.iloc[1])\n", " sv.plot_spectrum(df_high.iloc[0], challenges.iloc[2])\n", " plt.show()\n", "\n", - " print(f\"Minimum cosine score to merged spectra: {min_ref_score:.2f} (pass)\")\n", + " print(f'Minimum cosine score to merged spectra: {min_ref_score:.2f} (pass)')\n", "\n", " return challenges, ref_scores" ] @@ -321,8 +316,8 @@ } ], "source": [ - "l = f\"{home}/data/metabolites/CASMI_2022/download/\"\n", - "f = \"A_M3_posPFP_01.mzml\"\n", + "l = f'{home}/data/metabolites/CASMI_2022/download/'\n", + "f = 'A_M3_posPFP_01.mzml'\n", "path = os.path.join(l, f)\n", "\n", "\n", @@ -460,9 +455,9 @@ } ], "source": [ - "l = f\"{home}/data/metabolites/CASMI_2022/\"\n", - "key_file = \"MetSoc2022_CASMI_Workshop_Challenges_KEY_ALL_FINAL.csv\"\n", - "challenge_key = pd.read_csv(os.path.join(l, key_file), sep=\"\\t\")\n", + "l = f'{home}/data/metabolites/CASMI_2022/'\n", + "key_file = 'MetSoc2022_CASMI_Workshop_Challenges_KEY_ALL_FINAL.csv'\n", + "challenge_key = pd.read_csv(os.path.join(l, key_file), sep='\\t')\n", "challenge_key.head(3)" ] }, @@ -618,13 +613,13 @@ "source": [ "challenge_key = challenge_key[\n", " np.logical_or(\n", - " challenge_key[\"Adduct\"] == \"[M+H]+\", challenge_key[\"Adduct\"] == \"[M-H]-\"\n", + " challenge_key['Adduct'] == '[M+H]+', challenge_key['Adduct'] == '[M-H]-'\n", " )\n", "]\n", "\n", "print(challenge_key.shape)\n", "\n", - "sns.histplot(challenge_key, x=\"Adduct\")\n", + "sns.histplot(challenge_key, x='Adduct')\n", "plt.show()\n", "challenge_key.head(3)" ] }, @@ -790,14 +785,14 @@ "\n", "\n", "for i, d in challenge_key.iterrows():\n", - " file = d[\"File\"] + \".mzml\"\n", - " rt = d[\"RT [min]\"]\n", - " precursor_mz = d[\"Precursor m/z (Da)\"]\n", - " adduct = d[\"Adduct\"]\n", - " name = d[\"Compound Number\"]\n", - " smiles = d[\"SMILES\"].strip()\n", + " file = d['File'] + '.mzml'\n", + " rt = d['RT [min]']\n", + " precursor_mz = d['Precursor m/z (Da)']\n", + " adduct = d['Adduct']\n", + " name = d['Compound Number']\n", + " smiles = d['SMILES'].strip()\n", "\n", - " path = os.path.join(l, \"mzml/\", file)\n", + " path = os.path.join(l, 'mzml/', file)\n", "\n", " c, ref_scores = extract_challenge(\n", " path,\n", @@ -809,16 +804,16 @@ " verbose=False,\n", " ) # 5 seconds tolerance, 10 ppm precursor mass\n", " if c.shape[0] == 0:\n", - " print(f\"No match found for Compound Number {name}\")\n", + " print(f'No match found for Compound Number {name}')\n", " misses += 1\n", " continue\n", " avg_ref_score += ref_scores\n", " avg_min_ref_score += [min(ref_scores)]\n", " avg_num_spectra += [len(ref_scores)]\n", - " c[\"Precursor_type\"] = adduct\n", - " c[\"ChallengeName\"] = \"Challenge-\" + str(name)\n", - " c[\"ChallengeRT\"] = rt\n", - " c[\"SMILES\"] = smiles\n", + " c['Precursor_type'] = adduct\n", + " c['ChallengeName'] = 'Challenge-' + str(name)\n", + " c['ChallengeRT'] = rt\n", + " c['SMILES'] = smiles\n", " challenges = pd.concat([challenges, c], axis=0)" ] }, @@ -842,8 +837,8 @@ } ], "source": [ - "challenges[challenges[\"ChallengeName\"] == \"Challenge-277\"][\n", - " \"SMILES\"\n", + "challenges[challenges['ChallengeName'] == 'Challenge-277'][\n", + " 'SMILES'\n", "] # .apply(lambda x: x.strip())" ] }, { @@ -1011,10 +1006,10 @@ } ], "source": [ - "print(f\"Missed compounds: {misses}\")\n", - "print(f\"Mean ref score: {np.mean(avg_ref_score)}\")\n", - "print(f\"Mean min score: {np.mean(avg_min_ref_score)}\")\n", - "print(f\"Median num spec: {np.median(avg_num_spectra)}\")" + "print(f'Missed compounds: {misses}')\n", + "print(f'Mean ref score: {np.mean(avg_ref_score)}')\n", + "print(f'Mean min score: {np.mean(avg_min_ref_score)}')\n", + "print(f'Median num spec: {np.median(avg_num_spectra)}')" ] }, { @@ -1023,7 +1018,7 @@ "metadata": {}, "outputs": [], "source": [ - "raise KeyboardInterrupt(\"Stop. Make sure you want to save the DataFrames\")" + "raise KeyboardInterrupt('Stop. Make sure you want to save the DataFrames')" ] }, { @@ -1041,11 +1036,11 @@ "outputs": [], "source": [ "save_df = True\n", - "name = \"casmi22_challenges_combined_accurate.csv\"\n", + "name = 'casmi22_challenges_combined_accurate.csv'\n", "\n", "if save_df:\n", " file = os.path.join(l, name)\n", - " print(\"saving to \", file)\n", + " print('saving to ', file)\n", " challenges.to_csv(file)\n", "\n", " # df_nist[key_columns].to_csv(library_directory + name + \"all\" + \"_\" + date + \".csv\")" @@ -1067,23 +1062,23 @@ "source": [ "## prepare output for CFM-ID\n", "save_df = False\n", - "cfm_directory = f\"{home}/data/metabolites/cfm-id/\"\n", - "name = \"casmi22_negative_solutions_cfm.txt\"\n", - "unique_challenges = challenges.drop_duplicates(subset=\"ChallengeName\", keep=\"first\")\n", + "cfm_directory = f'{home}/data/metabolites/cfm-id/'\n", + "name = 'casmi22_negative_solutions_cfm.txt'\n", + "unique_challenges = challenges.drop_duplicates(subset='ChallengeName', keep='first')\n", "\n", "if save_df:\n", " file = os.path.join(cfm_directory, name)\n", - " print(\"saving to \", file)\n", - " unique_challenges[unique_challenges[\"Precursor_type\"] == \"[M-H]-\"][\n", - " [\"ChallengeName\", \"SMILES\"]\n", - " ].to_csv(file, index=False, header=False, sep=\" \")\n", + " print('saving to ', file)\n", + " unique_challenges[unique_challenges['Precursor_type'] == '[M-H]-'][\n", + " ['ChallengeName', 'SMILES']\n", + " ].to_csv(file, index=False, header=False, sep=' ')\n", "\n", - " name = name.replace(\"negative\", \"positive\")\n", + " name = name.replace('negative', 'positive')\n", " file = os.path.join(cfm_directory, name)\n", - " print(\"saving to \", file)\n", - " unique_challenges[unique_challenges[\"Precursor_type\"] == \"[M+H]+\"][\n", - " [\"ChallengeName\", \"SMILES\"]\n", - " ].to_csv(file, index=False, header=False, sep=\" \")" + " print('saving to ', file)\n", + " unique_challenges[unique_challenges['Precursor_type'] == '[M+H]+'][\n", + " ['ChallengeName', 'SMILES']\n", + " 
].to_csv(file, index=False, header=False, sep=' ')" ] }, { @@ -1092,7 +1087,7 @@ "metadata": {}, "outputs": [], "source": [ - "raise KeyboardInterrupt(\"Until here. AND NO FURTHER\")" + "raise KeyboardInterrupt('Until here. AND NO FURTHER')" ] }, { @@ -1124,7 +1119,7 @@ } ], "source": [ - "sns.displot(challenge_key, x=\"RT [min]\")" + "sns.displot(challenge_key, x='RT [min]')" ] }, { @@ -1133,7 +1128,7 @@ "metadata": {}, "outputs": [], "source": [ - "raise ValueError(\"STOP HERE\")" + "raise ValueError('STOP HERE')" ] }, { @@ -1165,7 +1160,7 @@ "source": [ "comp = [46, 64, 78, 79, 118, 130, 163, 165]\n", "\n", - "print(challenge_key[challenge_key[\"Compound Number\"] == 78])\n", + "print(challenge_key[challenge_key['Compound Number'] == 78])\n", "\n", "# print(extract_challenge(249.1496))" ] @@ -1196,10 +1191,10 @@ } ], "source": [ - "peaks1 = {\"mz\": [2.5, 2.8, 5], \"intensity\": [8, 5, 9]}\n", - "peaks2 = {\"mz\": [2.51, 3.0, 10], \"intensity\": [2, 2, 2]}\n", - "normalize_spectrum(peaks1, \"norm\")\n", - "normalize_spectrum(peaks2, \"norm\")\n", + "peaks1 = {'mz': [2.5, 2.8, 5], 'intensity': [8, 5, 9]}\n", + "peaks2 = {'mz': [2.51, 3.0, 10], 'intensity': [2, 2, 2]}\n", + "normalize_spectrum(peaks1, 'norm')\n", + "normalize_spectrum(peaks2, 'norm')\n", "\n", "mm = merge_spectrum(peaks1, peaks2, merge_tolerance=0.01)\n", "print(peaks1)\n", @@ -1302,8 +1297,8 @@ ], "source": [ "vec, vec_ref = np.zeros(len(mz_map)), np.zeros(len(mz_map))\n", - "np.add.at(vec, bins, peaks1[\"intensity\"]) # vec.put(bins, spec[\"intensity\"])\n", - "np.add.at(vec_ref, bins_ref, mm[\"intensity\"])\n", + "np.add.at(vec, bins, peaks1['intensity']) # vec.put(bins, spec[\"intensity\"])\n", + "np.add.at(vec_ref, bins_ref, mm['intensity'])\n", "\n", "print(vec)\n", "print(vec_ref)" @@ -1355,8 +1350,8 @@ } ], "source": [ - "print(df_low.iloc[0][\"peaks\"][\"mz\"][-10:])\n", - "print(df_low.iloc[1][\"peaks\"][\"mz\"][-10:])" + "print(df_low.iloc[0]['peaks']['mz'][-10:])\n", + "print(df_low.iloc[1]['peaks']['mz'][-10:])" ] }, { @@ -1412,7 +1407,7 @@ ], "source": [ "for i in spec.getAcquisitionInfo():\n", - " print(i.getMetaValue(\"[Thermo Trailer Extra]Monoisotopic M/Z:\"))" + " print(i.getMetaValue('[Thermo Trailer Extra]Monoisotopic M/Z:'))" ] }, { @@ -1432,7 +1427,7 @@ } ], "source": [ - "spec.getAcquisitionInfo()[0].getMetaValue(\"[Thermo Trailer Extra]Monoisotopic M/Z:\")" + "spec.getAcquisitionInfo()[0].getMetaValue('[Thermo Trailer Extra]Monoisotopic M/Z:')" ] }, { @@ -1477,18 +1472,18 @@ ], "source": [ "# l = f\"{home}/data/metabolites/CASMI_2022/3_Data-20220516T091400Z-001/3_Data/mzML Data/1_Priority - Challenges 1-250/pos\"\n", - "l = f\"{home}/data/metabolites/CASMI_2022/download/\"\n", - "f = \"A_M3_posPFP_01.mzml\"\n", + "l = f'{home}/data/metabolites/CASMI_2022/download/'\n", + "f = 'A_M3_posPFP_01.mzml'\n", "\n", "run = pymzml.run.Reader(os.path.join(l, f))\n", "\n", "for spec in run:\n", - " mz, intensity = spec.peaks(\"centroided\")[:, 0], spec.peaks(\"centroided\")[:, 1]\n", + " mz, intensity = spec.peaks('centroided')[:, 0], spec.peaks('centroided')[:, 1]\n", " print(spec.scan_time)\n", " for i in spec.__dir__():\n", - " if \"filter\" in i:\n", + " if 'filter' in i:\n", " print(i)\n", - " sv.plot_spectrum({\"peaks\": {\"mz\": mz, \"intensity\": intensity}})\n", + " sv.plot_spectrum({'peaks': {'mz': mz, 'intensity': intensity}})\n", " plt.show()\n", " break" ] @@ -1536,7 +1531,7 @@ } ], "source": [ - "from pyteomics import mzml, auxiliary, mgf\n", + "from pyteomics import auxiliary, mgf, mzml\n", 
"\n", "path = os.path.join(l, f)\n", "reader = mzml.read(path)\n", @@ -1576,8 +1571,8 @@ "metadata": {}, "outputs": [], "source": [ - "from matchms.importing import load_from_mzml\n", "from matchms.filtering import default_filters\n", + "from matchms.importing import load_from_mzml\n", "\n", "reader = load_from_mzml(path)" ] @@ -1626,7 +1621,7 @@ "metadata": {}, "outputs": [], "source": [ - "df = pd.DataFrame([{\"spectrum\": s, **s.metadata} for s in spectrums])" + "df = pd.DataFrame([{'spectrum': s, **s.metadata} for s in spectrums])" ] }, { @@ -1643,13 +1638,13 @@ } ], "source": [ - "df[\"RT\"] = df.scan_start_time.apply(lambda x: x[0])\n", - "df[\"is_rt_match\"] = abs(df[\"RT\"] - 5.55) < 0.1\n", + "df['RT'] = df.scan_start_time.apply(lambda x: x[0])\n", + "df['is_rt_match'] = abs(df['RT'] - 5.55) < 0.1\n", "\n", - "df[\"is_mz_match\"] = abs(df[\"precursor_mz\"] - 719.2546) < 0.1\n", - "df[\"is_match\"] = np.logical_and(df[\"is_rt_match\"], df[\"is_mz_match\"])\n", - "example = df[df[\"is_match\"]].iloc[0]\n", - "spec = example[\"spectrum\"]\n", + "df['is_mz_match'] = abs(df['precursor_mz'] - 719.2546) < 0.1\n", + "df['is_match'] = np.logical_and(df['is_rt_match'], df['is_mz_match'])\n", + "example = df[df['is_match']].iloc[0]\n", + "spec = example['spectrum']\n", "print()" ] }, @@ -1881,7 +1876,6 @@ "source": [ "from pyopenms import *\n", "\n", - "\n", "print(open(path).readlines()[197].strip())\n", "print(path)" ] @@ -1994,7 +1988,7 @@ } ], "source": [ - "getters = [x for x in exp.getSpectrum(1).__dir__() if \"get\" in x]\n", + "getters = [x for x in exp.getSpectrum(1).__dir__() if 'get' in x]\n", "spec1 = None\n", "\n", "for spec in exp:\n", @@ -2020,7 +2014,7 @@ } ], "source": [ - "print(spec1.getMetaValue(\"filter string\"))" + "print(spec1.getMetaValue('filter string'))" ] }, { @@ -2667,7 +2661,7 @@ "\n", "for line in lines:\n", " line = line.strip()\n", - " if \"scan=2672\" in line:\n", + " if 'scan=2672' in line:\n", " print(line)\n", " # if 'name=\"ms level\" value=\"2\"' in line: print(line)" ] @@ -2710,11 +2704,11 @@ } ], "source": [ - "print(spec.getMetaValue(\"base peak m/z\"))\n", - "print(spec.getMetaValue(\"base peak intensity\"))\n", - "print(spec.getMetaValue(\"total ion current\"))\n", - "print(spec.getMetaValue(\"filter string\"))\n", - "print(spec.getMetaValue(\"[Thermo Trailer Extra]Monoisotopic M/Z:\"))\n", + "print(spec.getMetaValue('base peak m/z'))\n", + "print(spec.getMetaValue('base peak intensity'))\n", + "print(spec.getMetaValue('total ion current'))\n", + "print(spec.getMetaValue('filter string'))\n", + "print(spec.getMetaValue('[Thermo Trailer Extra]Monoisotopic M/Z:'))\n", "print(spec.getMSLevel())\n", "print(spec.getMSLevel() == 2)" ] @@ -3219,7 +3213,7 @@ } ], "source": [ - "print(spec.getAcquisitionInfo().getMetaValue(\"filter string\"))" + "print(spec.getAcquisitionInfo().getMetaValue('filter string'))" ] }, { @@ -3246,7 +3240,7 @@ "for spec in exp:\n", " if spec.getMSLevel() == 2:\n", " if abs(spec.getRT() - 5.55) < 0.1:\n", - " print(\"yes\")" + " print('yes')" ] } ], diff --git a/lib_loader/gnps_library_loader.ipynb b/lib_loader/gnps_library_loader.ipynb index 78fc9f3..f7ca622 100644 --- a/lib_loader/gnps_library_loader.ipynb +++ b/lib_loader/gnps_library_loader.ipynb @@ -32,54 +32,46 @@ "source": [ "import sys\n", "\n", - "print(f\"Working with Python {sys.version}\")\n", + "print(f'Working with Python {sys.version}')\n", "\n", - "import numpy as np\n", - "import pandas as pd\n", + "import collections\n", "import importlib\n", + "import os\n", + 
"import time\n", "\n", "# import swifter\n", "import matplotlib.pyplot as plt\n", - "import seaborn as sns\n", - "import collections\n", - "import time\n", - "import os\n", - "from rdkit import Chem\n", - "from rdkit.Chem import AllChem\n", - "from rdkit.Chem import Draw\n", + "import numpy as np\n", + "import pandas as pd\n", "import rdkit.Chem.Descriptors as Descriptors\n", - "from rdkit.Chem import PandasTools\n", - "#!pip install spektral\n", - "\n", + "import seaborn as sns\n", "\n", + "#!pip install spektral\n", "# Deep Learning\n", "import sklearn\n", "\n", - "# import spektral\n", - "from sklearn.model_selection import train_test_split\n", - "\n", "# Keras\n", - "from sklearn.model_selection import train_test_split\n", - "\n", "# import stellargraph as sg\n", - "from rdkit import RDLogger\n", + "from rdkit import Chem, RDLogger\n", + "from rdkit.Chem import AllChem, Draw, PandasTools\n", "\n", + "# import spektral\n", + "from sklearn.model_selection import train_test_split\n", "\n", "#\n", "# Load Modules\n", - "sys.path.append(\"..\")\n", + "sys.path.append('..')\n", "from os.path import expanduser\n", "\n", - "home = expanduser(\"~\")\n", + "home = expanduser('~')\n", + "import fiora.IO.molReader as molReader\n", "import fiora.IO.mspReader as mspReader\n", "import fiora.visualization.spectrum_visualizer as sv\n", - "import fiora.IO.molReader as molReader\n", "\n", + "RDLogger.DisableLog('rdApp.*')\n", "\n", - "RDLogger.DisableLog(\"rdApp.*\")\n", "\n", - "\n", - "caffeine_smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\"\n", + "caffeine_smiles = 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C'\n", "caffeine_mol = Chem.MolFromSmiles(caffeine_smiles)\n", "\n", "caffeine_mol" @@ -99,8 +91,8 @@ } ], "source": [ - "library_name = \"cleaned_spectra.mgf\"\n", - "library_directory = f\"{home}/data/metabolites/GNPS/\"\n", + "library_name = 'cleaned_spectra.mgf'\n", + "library_directory = f'{home}/data/metabolites/GNPS/'\n", "!ls $library_directory" ] }, @@ -137,7 +129,7 @@ } ], "source": [ - "df.columns #" + "df.columns" ] }, { @@ -157,7 +149,7 @@ } ], "source": [ - "sum(~df[\"COLLISION_ENERGY\"].isna())" + "sum(~df['COLLISION_ENERGY'].isna())" ] }, { @@ -166,7 +158,7 @@ "metadata": {}, "outputs": [], "source": [ - "df[\"COLLISION_ENERGY\"] = df[\"COLLISION_ENERGY\"].astype(float)" + "df['COLLISION_ENERGY'] = df['COLLISION_ENERGY'].astype(float)" ] }, { @@ -198,7 +190,7 @@ "source": [ "# TODO ANALYSE DATASET\n", "\n", - "sns.histplot(data=df, x=\"COLLISION_ENERGY\", hue=\"MS_MANUFACTURER\", multiple=\"dodge\")" + "sns.histplot(data=df, x='COLLISION_ENERGY', hue='MS_MANUFACTURER', multiple='dodge')" ] }, { @@ -225,7 +217,7 @@ } ], "source": [ - "df[\"MS_MANUFACTURER\"].value_counts()" + "df['MS_MANUFACTURER'].value_counts()" ] }, { @@ -246,7 +238,7 @@ ], "source": [ "import seaborn as sns\n", - "import matplotlib.pyplot as plt\n", + "\n", "from fiora.visualization.define_colors import *\n", "\n", "set_light_theme()\n", @@ -256,24 +248,24 @@ "\n", "# Create a FacetGrid with each subplot for a different manufacturer\n", "g = sns.FacetGrid(\n", - " df.dropna(subset=[\"COLLISION_ENERGY\"]),\n", - " col=\"MS_MANUFACTURER\",\n", + " df.dropna(subset=['COLLISION_ENERGY']),\n", + " col='MS_MANUFACTURER',\n", " col_wrap=3,\n", " sharex=True,\n", " sharey=False,\n", ")\n", "g.map(\n", " sns.histplot,\n", - " \"COLLISION_ENERGY\",\n", + " 'COLLISION_ENERGY',\n", " bins=range(0, 101, 10),\n", " kde=False,\n", - " color=\"gray\",\n", - " edgecolor=\"black\",\n", + " color='gray',\n", + " edgecolor='black',\n", ")\n", "\n", "# 
Set common labels and titles\n", - "g.set_axis_labels(\"Collision Energy\", \"Count\")\n", - "g.set_titles(\"{col_name}\")\n", + "g.set_axis_labels('Collision Energy', 'Count')\n", + "g.set_titles('{col_name}')\n", "\n", "# Set x-axis limits for all subplots\n", "g.set(xlim=(0, 100))\n", @@ -304,8 +296,8 @@ } ], "source": [ - "library_name = \"ALL_GNPS\"\n", - "library_directory = f\"{home}/data/metabolites/GNPS/download_06_23/\"\n", + "library_name = 'ALL_GNPS'\n", + "library_directory = f'{home}/data/metabolites/GNPS/download_06_23/'\n", "!ls $library_directory" ] }, @@ -317,7 +309,7 @@ "source": [ "import json\n", "\n", - "f = open(library_directory + library_name + \".json\")\n", + "f = open(library_directory + library_name + '.json')\n", "js = json.load(f)\n", "df = pd.DataFrame(js)" ] @@ -364,7 +356,7 @@ "metadata": {}, "outputs": [], "source": [ - "ms = df[df[\"library_membership\"] == \"MASSBANK\"]" + "ms = df[df['library_membership'] == 'MASSBANK']" ] }, { @@ -405,7 +397,7 @@ } ], "source": [ - "ms.iloc[0][\"annotation_history\"]" + "ms.iloc[0]['annotation_history']" ] }, { @@ -426,7 +418,7 @@ } ], "source": [ - "a = \"NC1=C2C=CC(=CC2=CC=C1)S(O)(=O)=O\"\n", + "a = 'NC1=C2C=CC(=CC2=CC=C1)S(O)(=O)=O'\n", "\n", "Chem.MolFromSmiles(a)" ] @@ -474,13 +466,13 @@ } ], "source": [ - "nist_msp = mspReader.read(library_directory + library_name + \".MSP\")\n", + "nist_msp = mspReader.read(library_directory + library_name + '.MSP')\n", "df_nist = pd.DataFrame(nist_msp)\n", "\n", "# df_nist['mol'] = df_nist['SMILES'].apply(Chem.MolFromSmiles)\n", "# df_nist.dropna(inplace=True)\n", "print(\n", - " f\"Spectral file loaded with {df_nist.shape[0]} entries and {df_nist.shape[1]} variables\"\n", + " f'Spectral file loaded with {df_nist.shape[0]} entries and {df_nist.shape[1]} variables'\n", ")" ] }, @@ -552,7 +544,7 @@ ], "source": [ "# Example\n", - "example_entry = \"Desipramine\"\n", + "example_entry = 'Desipramine'\n", "# x = df_nist[df_nist[\"Name\"] == example_entry].iloc[0]\n", "EXAMPLE_ID = 32271\n", "x = df_nist.loc[EXAMPLE_ID]\n", @@ -578,18 +570,18 @@ "outputs": [], "source": [ "# Define figure styles\n", - "color_palette = sns.color_palette(\"magma_r\", 8)\n", + "color_palette = sns.color_palette('magma_r', 8)\n", "sns.set_theme(\n", - " style=\"whitegrid\",\n", + " style='whitegrid',\n", " rc={\n", - " \"axes.edgecolor\": \"black\",\n", - " \"ytick.left\": True,\n", - " \"xtick.bottom\": True,\n", - " \"xtick.color\": \"black\",\n", - " \"axes.spines.bottom\": True,\n", - " \"axes.spines.right\": True,\n", - " \"axes.spines.top\": True,\n", - " \"axes.spines.left\": True,\n", + " 'axes.edgecolor': 'black',\n", + " 'ytick.left': True,\n", + " 'xtick.bottom': True,\n", + " 'xtick.color': 'black',\n", + " 'axes.spines.bottom': True,\n", + " 'axes.spines.right': True,\n", + " 'axes.spines.top': True,\n", + " 'axes.spines.left': True,\n", " },\n", ")" ] @@ -614,25 +606,25 @@ ], "source": [ "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 2]}, sharey=True\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 2]}, sharey=True\n", ")\n", "fig.set_tight_layout(False)\n", "for ax in axs:\n", - " ax.tick_params(\"x\", labelrotation=45)\n", + " ax.tick_params('x', labelrotation=45)\n", "\n", "sns.countplot(\n", - " ax=axs[0], data=df_nist, x=\"Spectrum_type\", edgecolor=\"black\", palette=color_palette\n", + " ax=axs[0], data=df_nist, x='Spectrum_type', edgecolor='black', palette=color_palette\n", ")\n", "sns.countplot(\n", " 
ax=axs[1],\n", " data=df_nist,\n", - " x=\"Precursor_type\",\n", - " edgecolor=\"black\",\n", + " x='Precursor_type',\n", + " edgecolor='black',\n", " palette=color_palette,\n", - " order=df_nist[\"Precursor_type\"].value_counts().iloc[:8].index,\n", + " order=df_nist['Precursor_type'].value_counts().iloc[:8].index,\n", ")\n", "axs[0].set_ylim(0, 500000)\n", - "axs[1].set_ylabel(\"\")\n", + "axs[1].set_ylabel('')\n", "\n", "plt.show()" ] @@ -652,18 +644,18 @@ ], "source": [ "# Filters\n", - "df_nist = df_nist[df_nist[\"Spectrum_type\"] == \"MS2\"]\n", - "target_precursor_type = [\"[M+H]+\", \"[M-H]-\", \"[M+H-H2O]+\", \"[M+Na]+\"]\n", + "df_nist = df_nist[df_nist['Spectrum_type'] == 'MS2']\n", + "target_precursor_type = ['[M+H]+', '[M-H]-', '[M+H-H2O]+', '[M+Na]+']\n", "df_nist = df_nist[\n", - " df_nist[\"Precursor_type\"].apply(lambda ptype: ptype in target_precursor_type)\n", + " df_nist['Precursor_type'].apply(lambda ptype: ptype in target_precursor_type)\n", "]\n", "\n", "# Formats\n", - "df_nist[\"PrecursorMZ\"] = df_nist[\"PrecursorMZ\"].astype(\"float\")\n", - "df_nist[\"Num peaks\"] = df_nist[\"Num peaks\"].astype(\"int\")\n", + "df_nist['PrecursorMZ'] = df_nist['PrecursorMZ'].astype('float')\n", + "df_nist['Num peaks'] = df_nist['Num peaks'].astype('int')\n", "\n", "\n", - "print(f\"Spectral file filtered down to {df_nist.shape[0]} entries\")" + "print(f'Spectral file filtered down to {df_nist.shape[0]} entries')" ] }, { @@ -686,27 +678,27 @@ ], "source": [ "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 2]}, sharey=False\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 2]}, sharey=False\n", ")\n", "for ax in axs:\n", - " ax.tick_params(\"x\", labelrotation=45)\n", + " ax.tick_params('x', labelrotation=45)\n", "\n", "sns.boxplot(\n", - " ax=axs[0], data=df_nist, y=\"PrecursorMZ\", palette=color_palette, x=\"Precursor_type\"\n", + " ax=axs[0], data=df_nist, y='PrecursorMZ', palette=color_palette, x='Precursor_type'\n", ")\n", "sns.histplot(\n", " ax=axs[1],\n", " data=df_nist,\n", - " x=\"Num peaks\",\n", + " x='Num peaks',\n", " color=color_palette[7],\n", " fill=True,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", ") # , order=list(range(0,200)))\n", - "plt.rcParams[\"patch.force_edgecolor\"] = True\n", - "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n", - "axs[1].set_ylabel(\"\")\n", + "plt.rcParams['patch.force_edgecolor'] = True\n", + "plt.rcParams['axes.edgecolor'] = 'black'\n", + "axs[1].set_ylabel('')\n", "axs[1].set_xlim([0, 100])\n", - "axs[1].axvline(x=50, color=\"red\", linestyle=\"-.\")\n", + "axs[1].axvline(x=50, color='red', linestyle='-.')\n", "\n", "plt.show()" ] @@ -732,24 +724,24 @@ "source": [ "# associate MOL structures with MS2 spectra\n", "\n", - "file = library_directory + library_name + \".MOL/\" + \"S\" + x[\"CASNO\"] + \".MOL\"\n", + "file = library_directory + library_name + '.MOL/' + 'S' + x['CASNO'] + '.MOL'\n", "x_mol = molReader.load_MOL(file)\n", "x_mol\n", "\n", "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", ")\n", "\n", "img = Chem.Draw.MolToImage(x_mol, ax=axs[0])\n", "\n", "axs[0].grid(False)\n", "axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", ")\n", - 
"axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + Chem.MolToSmiles(x_mol))\n", + "axs[0].set_title(x['Name'] + ' structure:\\n' + Chem.MolToSmiles(x_mol))\n", "axs[0].imshow(img)\n", - "axs[0].axis(\"off\")\n", - "sv.plot_spectrum(title=x[\"Name\"] + \" MS/MS spectrum\", spectrum=x, ax=axs[1])" + "axs[0].axis('off')\n", + "sv.plot_spectrum(title=x['Name'] + ' MS/MS spectrum', spectrum=x, ax=axs[1])" ] }, { @@ -772,34 +764,34 @@ "# print(df_nist.loc[1474])\n", "\n", "print(\n", - " \"Reading structure information in MOL format from library files (this may take a while)\"\n", + " 'Reading structure information in MOL format from library files (this may take a while)'\n", ")\n", "\n", "\n", "def fetch_mol(data):\n", " file = (\n", - " library_directory + library_name + \".MOL/\" + \"S\" + str(data[\"CASNO\"]) + \".MOL\"\n", + " library_directory + library_name + '.MOL/' + 'S' + str(data['CASNO']) + '.MOL'\n", " )\n", " if not os.path.exists(file):\n", " file = (\n", - " library_directory + library_name + \".MOL/\" + \"ID\" + str(data[\"ID\"]) + \".MOL\"\n", + " library_directory + library_name + '.MOL/' + 'ID' + str(data['ID']) + '.MOL'\n", " )\n", " return molReader.load_MOL(file)\n", "\n", "\n", "df_nist = df_nist[\n", - " ~df_nist[\"InChIKey\"].isnull()\n", + " ~df_nist['InChIKey'].isnull()\n", "] # Drop all without key (Not neccessarily neccesary)\n", - "df_nist[\"MOL\"] = df_nist.apply(fetch_mol, axis=1)\n", + "df_nist['MOL'] = df_nist.apply(fetch_mol, axis=1)\n", "print(\n", - " f\"Successfully interpreted {sum(df_nist['MOL'].notna())} from {df_nist.shape[0]} entries. Dropping the rest.\"\n", + " f'Successfully interpreted {sum(df_nist[\"MOL\"].notna())} from {df_nist.shape[0]} entries. Dropping the rest.'\n", ")\n", "\n", - "df_nist = df_nist[df_nist[\"MOL\"].notna()]\n", - "df_nist[\"SMILES\"] = df_nist[\"MOL\"].apply(Chem.MolToSmiles)\n", - "df_nist[\"InChI\"] = df_nist[\"MOL\"].apply(Chem.MolToInchi)\n", - "df_nist[\"K\"] = df_nist[\"MOL\"].apply(Chem.MolToInchiKey)\n", - "df_nist[\"ExactMolWeight\"] = df_nist[\"MOL\"].apply(Chem.Descriptors.ExactMolWt)\n", + "df_nist = df_nist[df_nist['MOL'].notna()]\n", + "df_nist['SMILES'] = df_nist['MOL'].apply(Chem.MolToSmiles)\n", + "df_nist['InChI'] = df_nist['MOL'].apply(Chem.MolToInchi)\n", + "df_nist['K'] = df_nist['MOL'].apply(Chem.MolToInchiKey)\n", + "df_nist['ExactMolWeight'] = df_nist['MOL'].apply(Chem.Descriptors.ExactMolWt)\n", "\n", "# for i in df_nist.index:\n", "# tight_layout\n", @@ -871,14 +863,14 @@ } ], "source": [ - "print(df_nist[df_nist[\"Name\"] == example_entry].iloc[0])\n", - "print(len(df_nist[\"MOL\"].unique()))\n", + "print(df_nist[df_nist['Name'] == example_entry].iloc[0])\n", + "print(len(df_nist['MOL'].unique()))\n", "\n", "\n", "df_nist.shape\n", - "df_nist[\"SMILES\"].isnull().any()\n", + "df_nist['SMILES'].isnull().any()\n", "\n", - "df_nist[df_nist[\"Name\"] == example_entry].iloc[0][\"SMILES\"]" + "df_nist[df_nist['Name'] == example_entry].iloc[0]['SMILES']" ] }, { @@ -898,24 +890,24 @@ } ], "source": [ - "correct_keys = df_nist.apply(lambda x: x[\"InChIKey\"] == x[\"K\"], axis=1)\n", - "s = \"confirmed!\" if correct_keys.all() else \"not confirmed !! Attention!\"\n", + "correct_keys = df_nist.apply(lambda x: x['InChIKey'] == x['K'], axis=1)\n", + "s = 'confirmed!' if correct_keys.all() else 'not confirmed !! Attention!'\n", "print(\n", - " f\"Confirming whether computed and provided InChI-Keys are correct. 
Result: {s} ({correct_keys.sum() / len(correct_keys):0.2f} correct)\"\n",
+    " f'Confirming whether computed and provided InChI-Keys are correct. Result: {s} ({correct_keys.sum() / len(correct_keys):0.2f} correct)'\n",
 ")\n",
 "half_keys = df_nist.apply(\n",
-    " lambda x: x[\"InChIKey\"].split(\"-\")[0] == x[\"K\"].split(\"-\")[0], axis=1\n",
+    " lambda x: x['InChIKey'].split('-')[0] == x['K'].split('-')[0], axis=1\n",
 ")\n",
-    "s = \"confirmed!\" if half_keys.all() else \"not confirmed !! Attention!\"\n",
+    "s = 'confirmed!' if half_keys.all() else 'not confirmed !! Attention!'\n",
 "print(\n",
-    " f\"Checking if main layer InChI-Keys are correct. Result: {s} ({half_keys.sum() / len(half_keys):0.3f} correct)\"\n",
+    " f'Checking if main layer InChI-Keys are correct. Result: {s} ({half_keys.sum() / len(half_keys):0.3f} correct)'\n",
 ")\n",
 "\n",
-    "print(\"Dropping all other.\")\n",
-    "df_nist[\"matching_key\"] = df_nist.apply(lambda x: x[\"InChIKey\"] == x[\"K\"], axis=1)\n",
-    "df_nist = df_nist[df_nist[\"matching_key\"]]\n",
+    "print('Dropping all others.')\n",
+    "df_nist['matching_key'] = df_nist.apply(lambda x: x['InChIKey'] == x['K'], axis=1)\n",
+    "df_nist = df_nist[df_nist['matching_key']]\n",
 "\n",
-    "print(f\"Shape: {df_nist.shape}\")"
+    "print(f'Shape: {df_nist.shape}')"
 ]
 },
 {
@@ -924,7 +916,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-    "df_nist[\"ExactMolWeight\"] = df_nist[\"MOL\"].apply(Chem.Descriptors.ExactMolWt)"
+    "df_nist['ExactMolWeight'] = df_nist['MOL'].apply(Chem.Descriptors.ExactMolWt)"
 ]
 },
 {
@@ -951,23 +943,22 @@
 "source": [
 "MIN_PEAKS = 2\n",
 "MAX_PEAKS = 30\n",
-    "PRECURSOR_TYPES = [\"[M+H]+\"]\n",
+    "PRECURSOR_TYPES = ['[M+H]+']\n",
 "from modules.MOL.constants import ADDUCT_WEIGHTS\n",
 "\n",
-    "\n",
-    "df_nist = df_nist[df_nist[\"Num peaks\"] > MIN_PEAKS]\n",
-    "df_nist = df_nist[df_nist[\"Num peaks\"] < MAX_PEAKS]\n",
-    "df_nist[\"theoretical_precursor_mz\"] = df_nist[\"ExactMolWeight\"] + df_nist[\n",
-    " \"Precursor_type\"\n",
+    "df_nist = df_nist[df_nist['Num peaks'] > MIN_PEAKS]\n",
+    "df_nist = df_nist[df_nist['Num peaks'] < MAX_PEAKS]\n",
+    "df_nist['theoretical_precursor_mz'] = df_nist['ExactMolWeight'] + df_nist[\n",
+    " 'Precursor_type'\n",
 "].map(ADDUCT_WEIGHTS)\n",
 "df_nist = df_nist[\n",
-    " df_nist[\"Precursor_type\"].apply(lambda ptype: ptype in PRECURSOR_TYPES)\n",
+    " df_nist['Precursor_type'].apply(lambda ptype: ptype in PRECURSOR_TYPES)\n",
 "]\n",
-    "df_nist[\"precursor_offset\"] = (\n",
-    " df_nist[\"PrecursorMZ\"] - df_nist[\"theoretical_precursor_mz\"]\n",
+    "df_nist['precursor_offset'] = (\n",
+    " df_nist['PrecursorMZ'] - df_nist['theoretical_precursor_mz']\n",
 ")\n",
 "\n",
-    "print(f\"Shape {df_nist.shape}\")"
+    "print(f'Shape {df_nist.shape}')"
 ]
 },
 {
@@ -997,29 +988,29 @@
 ],
 "source": [
 "fig, axs = plt.subplots(\n",
-    " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 1.5]}, sharey=False\n",
+    " 1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 1.5]}, sharey=False\n",
 ")\n",
 "for ax in axs:\n",
-    " ax.tick_params(\"x\", labelrotation=45)\n",
+    " ax.tick_params('x', labelrotation=45)\n",
 "\n",
 "sns.scatterplot(\n",
 " ax=axs[0],\n",
 " data=df_nist,\n",
-    " x=\"precursor_offset\",\n",
-    " y=\"PrecursorMZ\",\n",
+    " x='precursor_offset',\n",
+    " y='PrecursorMZ',\n",
 " palette=color_palette,\n",
 ")\n",
 "sns.histplot(\n",
 " ax=axs[1],\n",
 " data=df_nist,\n",
-    " x=\"precursor_offset\",\n",
+    " x='precursor_offset',\n",
 " color=color_palette[7],\n",
 " fill=True,\n",
-    " edgecolor=\"black\",\n",
+    " 
edgecolor='black',\n", ") # , order=list(range(0,200)))\n", - "plt.rcParams[\"patch.force_edgecolor\"] = True\n", - "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n", - "axs[1].set_ylabel(\"\")\n", + "plt.rcParams['patch.force_edgecolor'] = True\n", + "plt.rcParams['axes.edgecolor'] = 'black'\n", + "axs[1].set_ylabel('')\n", "# axs[1].set_xlim([0, 100])\n", "# axs[1].axvline(x=50, color=\"red\", linestyle=\"-.\")\n", "\n", @@ -1067,11 +1058,11 @@ "def align_CE(ce, precursor_mz):\n", " if type(ce) == float:\n", " return ce\n", - " if \"eV\" in ce:\n", - " ce = ce.replace(\"eV\", \"\")\n", + " if 'eV' in ce:\n", + " ce = ce.replace('eV', '')\n", " return float(ce)\n", - " elif \"%\" in ce:\n", - " nce = ce.split(\"%\")[0].strip().split(\" \")[-1]\n", + " elif '%' in ce:\n", + " nce = ce.split('%')[0].strip().split(' ')[-1]\n", " try:\n", " nce = float(nce)\n", " return NCE_to_eV(nce, precursor_mz)\n", @@ -1088,12 +1079,12 @@ "charge_factor = {1: 1, 2: 0.9, 3: 0.85, 4: 0.8, 5: 0.75}\n", "\n", "\n", - "df_nist[\"CE\"] = df_nist.apply(\n", - " lambda x: align_CE(x[\"Collision_energy\"], x[\"theoretical_precursor_mz\"]), axis=1\n", + "df_nist['CE'] = df_nist.apply(\n", + " lambda x: align_CE(x['Collision_energy'], x['theoretical_precursor_mz']), axis=1\n", ") # modules.MOL.collision_energy.align_CE)\n", - "df_nist[\"CE_type\"] = df_nist[\"CE\"].apply(type)\n", - "df_nist[\"CE_derived_from_NCE\"] = df_nist[\"Collision_energy\"].apply(\n", - " lambda x: \"%\" in str(x)\n", + "df_nist['CE_type'] = df_nist['CE'].apply(type)\n", + "df_nist['CE_derived_from_NCE'] = df_nist['Collision_energy'].apply(\n", + " lambda x: '%' in str(x)\n", ")\n", "# df_test = df_nist[df_nist[\"Collision_energy\"].apply(lambda x: \"%\" in str(x))][\"Collision_energy\"]\n", "# df_test = df_test.apply(lambda x: x.split('%')[0].strip().split(' ')[-1])\n", @@ -1105,17 +1096,17 @@ "\n", "\n", "print(\n", - " \"Distinguish CE absolute values (eV - float) and normalized CE (in % - str format)\"\n", + " 'Distinguish CE absolute values (eV - float) and normalized CE (in % - str format)'\n", ")\n", - "print(df_nist[\"CE_type\"].value_counts())\n", + "print(df_nist['CE_type'].value_counts())\n", "\n", - "print(\"Removing all but absolute values\")\n", - "df_nist = df_nist[df_nist[\"CE_type\"] == float]\n", - "df_nist = df_nist[~df_nist[\"CE\"].isnull()]\n", + "print('Removing all but absolute values')\n", + "df_nist = df_nist[df_nist['CE_type'] == float]\n", + "df_nist = df_nist[~df_nist['CE'].isnull()]\n", "# len(df_nist['CE'].unique())\n", "\n", "print(\n", - " f\"Detected {len(df_nist['CE'].unique())} unique collision energies in range from {np.min(df_nist['CE'])} to {max(df_nist['CE'])} eV\"\n", + " f'Detected {len(df_nist[\"CE\"].unique())} unique collision energies in range from {np.min(df_nist[\"CE\"])} to {max(df_nist[\"CE\"])} eV'\n", ")" ] }, @@ -1153,19 +1144,19 @@ "sns.histplot(\n", " ax=ax,\n", " data=df_nist,\n", - " x=\"CE\",\n", - " hue=\"CE_derived_from_NCE\",\n", + " x='CE',\n", + " hue='CE_derived_from_NCE',\n", " palette=[color_palette[4], color_palette[2]],\n", - " multiple=\"stack\",\n", + " multiple='stack',\n", " fill=True,\n", " binwidth=2,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " binrange=[0, 200],\n", ") # , order=list(range(0,200)))\n", - "plt.rcParams[\"patch.force_edgecolor\"] = True\n", - "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n", + "plt.rcParams['patch.force_edgecolor'] = True\n", + "plt.rcParams['axes.edgecolor'] = 'black'\n", "plt.show()\n", - "print(f\"{df_nist.shape[0]} 
spectra remaining with aligned absolute collision energies\")" + "print(f'{df_nist.shape[0]} spectra remaining with aligned absolute collision energies')" ] }, { @@ -1185,19 +1176,19 @@ "outputs": [], "source": [ "%%capture\n", - "from modules.MOL.Metabolite import Metabolite\n", "from modules.MOL.constants import PPM\n", + "from modules.MOL.Metabolite import Metabolite\n", "\n", "TOLERANCE = 200 * PPM\n", "\n", "\n", - "df_nist[\"Metabolite\"] = df_nist[\"SMILES\"].apply(Metabolite)\n", - "df_nist[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", - "df_nist[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes())\n", - "df_nist[\"Metabolite\"].apply(lambda x: x.fragment_MOL())\n", + "df_nist['Metabolite'] = df_nist['SMILES'].apply(Metabolite)\n", + "df_nist['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", + "df_nist['Metabolite'].apply(lambda x: x.compute_graph_attributes())\n", + "df_nist['Metabolite'].apply(lambda x: x.fragment_MOL())\n", "df_nist.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=TOLERANCE\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=TOLERANCE\n", " ),\n", " axis=1,\n", ")" @@ -1209,17 +1200,16 @@ "metadata": {}, "outputs": [], "source": [ + "from modules.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", "from modules.MOL.mol_graph import (\n", - " mol_to_graph,\n", + " draw_graph,\n", " get_adjacency_matrix,\n", " get_degree_matrix,\n", " get_edges,\n", " get_identity_matrix,\n", - " draw_graph,\n", + " mol_to_graph,\n", ")\n", "\n", - "from modules.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", - "\n", "node_encoder = AtomFeatureEncoder()" ] }, @@ -1233,14 +1223,14 @@ "\n", "\n", "def add_dataframe_features(df):\n", - " df[\"graph\"] = df[\"MOL\"].apply(mol_to_graph)\n", + " df['graph'] = df['MOL'].apply(mol_to_graph)\n", " # df['features'] = df['graph'].apply(node_encoder.encode)\n", - " df[\"A\"] = df[\"graph\"].apply(get_adjacency_matrix)\n", - " df[\"Atilde\"] = df[\"A\"].apply(lambda x: x + np.eye(N=x.shape[0]))\n", - " df[\"Id\"] = df[\"A\"].apply(get_identity_matrix)\n", - " df[\"deg\"] = df[\"A\"].apply(get_degree_matrix)\n", - " df[\"is_aromatic\"] = df[\"graph\"].apply(\n", - " lambda x: np.array([[x.nodes[atom][\"is_aromatic\"] for atom in x.nodes()]]).T\n", + " df['A'] = df['graph'].apply(get_adjacency_matrix)\n", + " df['Atilde'] = df['A'].apply(lambda x: x + np.eye(N=x.shape[0]))\n", + " df['Id'] = df['A'].apply(get_identity_matrix)\n", + " df['deg'] = df['A'].apply(get_degree_matrix)\n", + " df['is_aromatic'] = df['graph'].apply(\n", + " lambda x: np.array([[x.nodes[atom]['is_aromatic'] for atom in x.nodes()]]).T\n", " )\n", "\n", " # Extras\n", @@ -1287,21 +1277,21 @@ "x = df_nist.loc[EXAMPLE_ID]\n", "\n", "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 1]}, sharey=False\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 1]}, sharey=False\n", ")\n", "\n", "img = Chem.Draw.MolToImage(x_mol, ax=axs[0])\n", "\n", "axs[0].grid(False)\n", "axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", ")\n", - "axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + Chem.MolToSmiles(x_mol))\n", + "axs[0].set_title(x['Name'] + ' structure:\\n' + 
Chem.MolToSmiles(x_mol))\n", "axs[0].imshow(img)\n", - "axs[0].axis(\"off\")\n", + "axs[0].axis('off')\n", "\n", - "g_img = draw_graph(x[\"graph\"], ax=axs[1])\n", - "print(x[\"peaks\"])" + "g_img = draw_graph(x['graph'], ax=axs[1])\n", + "print(x['peaks'])" ] }, { @@ -1357,8 +1347,8 @@ ], "source": [ "d = df_nist.iloc[0]\n", - "nx.convert_matrix.to_numpy_matrix(d[\"graph\"])\n", - "d[\"A\"]" + "nx.convert_matrix.to_numpy_matrix(d['graph'])\n", + "d['A']" ] }, { @@ -1385,7 +1375,7 @@ "\n", "x = df_nist.loc[EXAMPLE_ID]\n", "\n", - "FT = x[\"Metabolite\"].fragmentation_tree\n", + "FT = x['Metabolite'].fragmentation_tree\n", "\n", "FT.get_fragment(3)" ] @@ -1442,15 +1432,14 @@ "\n", "importlib.reload(modules.IO.fraggraphReader)\n", "importlib.reload(modules.MOL.FragmentationTree)\n", - "from modules.MOL.FragmentationTree import FragmentationTree\n", "\n", "\n", "f = fraggraphReader.parser_fraggraph_gen(\n", - " library_directory + \"examples/CNCCCN1c2ccccc2CCc2ccccc21_fraggraph.txt\"\n", + " library_directory + 'examples/CNCCCN1c2ccccc2CCc2ccccc21_fraggraph.txt'\n", ")\n", "x = df_nist.loc[EXAMPLE_ID]\n", "\n", - "match_fragment_lists(x[\"peaks\"][\"mz\"], f[\"fragments\"][\"mass\"])" + "match_fragment_lists(x['peaks']['mz'], f['fragments']['mass'])" ] }, { @@ -1509,27 +1498,27 @@ "source": [ "x = df_nist.loc[EXAMPLE_ID]\n", "\n", - "FT = x[\"Metabolite\"].fragmentation_tree\n", + "FT = x['Metabolite'].fragmentation_tree\n", "# frag.build_fragmentation_tree_by_rotatable_bond_breaks()\n", "print(FT)\n", "\n", "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(12.8, 2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", ")\n", "\n", - "img = Chem.Draw.MolToImage(x[\"MOL\"], ax=axs[0])\n", + "img = Chem.Draw.MolToImage(x['MOL'], ax=axs[0])\n", "\n", "axs[0].grid(False)\n", "axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", ")\n", - "axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + x[\"SMILES\"])\n", + "axs[0].set_title(x['Name'] + ' structure:\\n' + x['SMILES'])\n", "axs[0].imshow(img)\n", - "axs[0].axis(\"off\")\n", - "sv.plot_spectrum(title=x[\"Name\"] + \" MS/MS spectrum\", spectrum=x, ax=axs[1])\n", + "axs[0].axis('off')\n", + "sv.plot_spectrum(title=x['Name'] + ' MS/MS spectrum', spectrum=x, ax=axs[1])\n", "\n", - "print(\"Matching peaks to fragments\")\n", - "print(x[\"Metabolite\"].peak_matches)" + "print('Matching peaks to fragments')\n", + "print(x['Metabolite'].peak_matches)" ] }, { @@ -1565,20 +1554,20 @@ } ], "source": [ - "df_nist[\"peak_matches\"] = df_nist[\"Metabolite\"].apply(\n", - " lambda x: getattr(x, \"peak_matches\")\n", + "df_nist['peak_matches'] = df_nist['Metabolite'].apply(\n", + " lambda x: getattr(x, 'peak_matches')\n", ")\n", - "df_nist[\"num_peaks_matched\"] = df_nist[\"peak_matches\"].apply(len)\n", + "df_nist['num_peaks_matched'] = df_nist['peak_matches'].apply(len)\n", "\n", "\n", "def get_match_stats(matches):\n", " num_unique, num_conflicts, mode_count = (\n", " 0,\n", " 0,\n", - " {\"[M+H]+\": 0, \"[M-H]+\": 0, \"[M-3H]+\": 0},\n", + " {'[M+H]+': 0, '[M-H]+': 0, '[M-3H]+': 0},\n", " )\n", " for mz, match_data in matches.items():\n", - " candidates = match_data[\"fragments\"]\n", + " candidates = match_data['fragments']\n", " if len(candidates) == 1:\n", " num_unique += 1\n", " elif len(candidates) > 1:\n", @@ -1591,20 +1580,20 @@ "d 
= df_nist.loc[EXAMPLE_ID]\n",
 "\n",
 "\n",
-    "df_nist[\"match_stats\"] = df_nist[\"peak_matches\"].apply(lambda x: get_match_stats(x))\n",
-    "df_nist[\"num_unique_peaks_matched\"] = df_nist.apply(\n",
-    " lambda x: x[\"match_stats\"][0], axis=1\n",
+    "df_nist['match_stats'] = df_nist['peak_matches'].apply(lambda x: get_match_stats(x))\n",
+    "df_nist['num_unique_peaks_matched'] = df_nist.apply(\n",
+    " lambda x: x['match_stats'][0], axis=1\n",
 ")\n",
-    "df_nist[\"num_conflicts_in_peak_matching\"] = df_nist.apply(\n",
-    " lambda x: x[\"match_stats\"][1], axis=1\n",
+    "df_nist['num_conflicts_in_peak_matching'] = df_nist.apply(\n",
+    " lambda x: x['match_stats'][1], axis=1\n",
 ")\n",
-    "df_nist[\"match_mode_counts\"] = df_nist.apply(lambda x: x[\"match_stats\"][2], axis=1)\n",
-    "u = df_nist[\"num_unique_peaks_matched\"].sum()\n",
-    "s = df_nist[\"num_conflicts_in_peak_matching\"].sum()\n",
+    "df_nist['match_mode_counts'] = df_nist.apply(lambda x: x['match_stats'][2], axis=1)\n",
+    "u = df_nist['num_unique_peaks_matched'].sum()\n",
+    "s = df_nist['num_conflicts_in_peak_matching'].sum()\n",
 "print(\n",
-    " f\"Total number of uniquely matched peaks: {u} , conflicts found within {s} matches ({100 * s / (u + s):.02f} %))\"\n",
+    " f'Total number of uniquely matched peaks: {u}, conflicts found within {s} matches ({100 * s / (u + s):.02f} %)'\n",
 ")\n",
-    "print(f\"Total number of conflicting peak to fragment matches: {s}\")\n",
+    "print(f'Total number of conflicting peak-to-fragment matches: {s}')\n",
 "\n",
 "df_nist.shape"
 ]
 },
 {
@@ -1630,34 +1619,34 @@
 "source": [
 "fig, axs = plt.subplots(1, 3, figsize=(12.8, 6.4), sharey=False)\n",
 "\n",
-    "fig.suptitle(f\"Identified peaks with fragment offset\")\n",
+    "fig.suptitle('Identified peaks with fragment offset')\n",
 "# plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n",
 "sns.histplot(\n",
 " ax=axs[0],\n",
 " data=df_nist,\n",
-    " x=\"num_peaks_matched\",\n",
+    " x='num_peaks_matched',\n",
 " color=color_palette[0],\n",
-    " edgecolor=\"black\",\n",
+    " edgecolor='black',\n",
 " bins=range(0, 20, 1),\n",
 ")\n",
 "# axs[0].set_ylim(-0.5, 10)\n",
-    "axs[0].set_ylabel(\"peaks identified\")\n",
+    "axs[0].set_ylabel('peaks identified')\n",
 "\n",
 "\n",
 "sns.boxplot(\n",
-    " ax=axs[1], data=df_nist, y=\"num_unique_peaks_matched\", color=color_palette[1]\n",
+    " ax=axs[1], data=df_nist, y='num_unique_peaks_matched', color=color_palette[1]\n",
 ")\n",
 "axs[1].set_ylim(-0.5, 15)\n",
-    "axs[1].set_xlabel(\"unique matches\")\n",
-    "axs[1].set_ylabel(\"\")\n",
+    "axs[1].set_xlabel('unique matches')\n",
+    "axs[1].set_ylabel('')\n",
 "\n",
 "\n",
 "sns.boxplot(\n",
-    " ax=axs[2], data=df_nist, y=\"num_conflicts_in_peak_matching\", color=color_palette[3]\n",
+    " ax=axs[2], data=df_nist, y='num_conflicts_in_peak_matching', color=color_palette[3]\n",
 ")\n",
 "axs[2].set_ylim(-0.5, 15)\n",
-    "axs[2].set_xlabel(\"conflicts\")\n",
-    "axs[2].set_ylabel(\"\")\n",
+    "axs[2].set_xlabel('conflicts')\n",
+    "axs[2].set_ylabel('')\n",
 "\n",
 "plt.show()"
 ]
 },
 {
@@ -1683,7 +1672,7 @@
 "source": [
 "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), sharey=False)\n",
 "\n",
-    "mode_counts = {\"[M+H]+\": 0, \"[M-H]+\": 0, \"[M-3H]+\": 0}\n",
+    "mode_counts = {'[M+H]+': 0, '[M-H]+': 0, '[M-3H]+': 0}\n",
 "\n",
 "\n",
 "def update_mode_counts(m):\n",
 " for mode in m.keys():\n",
 " mode_counts[mode] += m[mode]\n",
 "\n",
 "\n",
-    "df_nist[\"match_mode_counts\"].apply(update_mode_counts)\n",
+    "df_nist['match_mode_counts'].apply(update_mode_counts)\n",
 "\n",
 "sns.barplot(\n",
 " ax=axs[0],\n",
 " 
x=list(mode_counts.keys()),\n", " y=[mode_counts[k] for k in mode_counts.keys()],\n", " palette=color_palette,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " linewidth=1.5,\n", ")\n", "axs[1].pie(\n", " [mode_counts[k] for k in mode_counts.keys()],\n", " labels=list(mode_counts.keys()),\n", " colors=color_palette,\n", - " wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5},\n", + " wedgeprops={'edgecolor': 'black', 'linewidth': 1.5},\n", ")\n", "\n", "plt.show()" @@ -1732,8 +1721,8 @@ "source": [ "for i in range(0, 6):\n", " print(\n", - " f\"Minimum {i} unique peaks identified (including precursors): \",\n", - " (df_nist[\"num_unique_peaks_matched\"] >= i).sum(),\n", + " f'Minimum {i} unique peaks identified (including precursors): ',\n", + " (df_nist['num_unique_peaks_matched'] >= i).sum(),\n", " )" ] }, @@ -1752,61 +1741,61 @@ "outputs": [], "source": [ "save_df = True\n", - "name = \"nist_msms_filtered\"\n", - "date = \"06_2023\"\n", + "name = 'nist_msms_filtered'\n", + "date = '06_2023'\n", "min_peaks = 5\n", "\n", "if save_df:\n", " key_columns = [\n", - " \"Name\",\n", - " \"Synon\",\n", - " \"Notes\",\n", - " \"Precursor_type\",\n", - " \"Spectrum_type\",\n", - " \"PrecursorMZ\",\n", - " \"Instrument_type\",\n", - " \"Instrument\",\n", - " \"Sample_inlet\",\n", - " \"Ionization\",\n", - " \"Collision_energy\",\n", - " \"Ion_mode\",\n", - " \"Special_fragmentation\",\n", - " \"InChIKey\",\n", - " \"Formula\",\n", - " \"MW\",\n", - " \"ExactMass\",\n", - " \"CASNO\",\n", - " \"NISTNO\",\n", - " \"ID\",\n", - " \"Comment\",\n", - " \"Num peaks\",\n", - " \"peaks\",\n", - " \"Link\",\n", - " \"Related_CAS#\",\n", - " \"Collision_gas\",\n", - " \"Pressure\",\n", - " \"In-source_voltage\",\n", - " \"msN_pathway\",\n", - " \"MOL\",\n", - " \"SMILES\",\n", - " \"InChI\",\n", - " \"K\",\n", - " \"ExactMolWeight\",\n", - " \"matching_key\",\n", - " \"theoretical_precursor_mz\",\n", - " \"precursor_offset\",\n", - " \"CE\",\n", - " \"CE_type\",\n", - " \"peak_matches\",\n", - " \"num_peaks_matched\",\n", - " \"match_stats\",\n", - " \"num_unique_peaks_matched\",\n", - " \"num_conflicts_in_peak_matching\",\n", - " \"match_mode_counts\",\n", + " 'Name',\n", + " 'Synon',\n", + " 'Notes',\n", + " 'Precursor_type',\n", + " 'Spectrum_type',\n", + " 'PrecursorMZ',\n", + " 'Instrument_type',\n", + " 'Instrument',\n", + " 'Sample_inlet',\n", + " 'Ionization',\n", + " 'Collision_energy',\n", + " 'Ion_mode',\n", + " 'Special_fragmentation',\n", + " 'InChIKey',\n", + " 'Formula',\n", + " 'MW',\n", + " 'ExactMass',\n", + " 'CASNO',\n", + " 'NISTNO',\n", + " 'ID',\n", + " 'Comment',\n", + " 'Num peaks',\n", + " 'peaks',\n", + " 'Link',\n", + " 'Related_CAS#',\n", + " 'Collision_gas',\n", + " 'Pressure',\n", + " 'In-source_voltage',\n", + " 'msN_pathway',\n", + " 'MOL',\n", + " 'SMILES',\n", + " 'InChI',\n", + " 'K',\n", + " 'ExactMolWeight',\n", + " 'matching_key',\n", + " 'theoretical_precursor_mz',\n", + " 'precursor_offset',\n", + " 'CE',\n", + " 'CE_type',\n", + " 'peak_matches',\n", + " 'num_peaks_matched',\n", + " 'match_stats',\n", + " 'num_unique_peaks_matched',\n", + " 'num_conflicts_in_peak_matching',\n", + " 'match_mode_counts',\n", " ]\n", - " file = library_directory + name + \"_min\" + str(min_peaks) + \"_\" + date + \".csv\"\n", - " print(\"saving to \", file)\n", - " df_nist[df_nist[\"num_unique_peaks_matched\"] >= min_peaks][key_columns].to_csv(file)\n", + " file = library_directory + name + '_min' + str(min_peaks) + '_' + date + '.csv'\n", + " print('saving to ', 
file)\n", + " df_nist[df_nist['num_unique_peaks_matched'] >= min_peaks][key_columns].to_csv(file)\n", "\n", " # df_nist[key_columns].to_csv(library_directory + name + \"all\" + \"_\" + date + \".csv\")" ] @@ -1839,10 +1828,10 @@ "outputs": [], "source": [ "import torch\n", - "from torch.utils.data import Dataset, DataLoader\n", - "from torchmetrics import Accuracy, MetricTracker\n", - "from modules.GNN.MLPEdgeClassifier import MLPEdgeClassifier\n", "from modules.GNN.GNNModels import GCNNodeClassifier\n", + "from modules.GNN.MLPEdgeClassifier import MLPEdgeClassifier\n", + "from torch.utils.data import DataLoader, Dataset\n", + "from torchmetrics import Accuracy, MetricTracker\n", "# importlib.reload(modules.GNN.GNNModels)\n", "# importlib.reload(modules.GNN.GNNLayers)" ] @@ -1855,7 +1844,7 @@ "source": [ "def train(model, dataloader_training, optimizer, loss_fn, tracker, epochs=10):\n", " for e in range(epochs):\n", - " print(f\"Epoch {e + 1}/{epochs}\")\n", + " print(f'Epoch {e + 1}/{epochs}')\n", " training_loss = 0\n", " tracker.increment()\n", " for batch_id, (X, y) in enumerate(dataloader_training):\n", @@ -1875,10 +1864,10 @@ "\n", " # Record loss\n", " training_loss += loss.item()\n", - " print(f\"Avg. training loss {training_loss / (batch_id + 1):>.3f}\", end=\"\\r\")\n", + " print(f'Avg. training loss {training_loss / (batch_id + 1):>.3f}', end='\\r')\n", "\n", - " print(\"\")\n", - " print(f\"Accuracy: {tracker.compute():>.4f}\")\n", + " print('')\n", + " print(f'Accuracy: {tracker.compute():>.4f}')\n", " training_loss /= len(dataloader_training)\n", "\n", " return training_loss\n", @@ -1902,7 +1891,7 @@ "\n", " val_accuracy = tracker.compute()\n", " validation_loss /= len(dataloader_val)\n", - " print(f\" Validation Accuracy: {val_accuracy:>.3f} (Loss: {validation_loss:>.3f})\")\n", + " print(f' Validation Accuracy: {val_accuracy:>.3f} (Loss: {validation_loss:>.3f})')\n", "\n", " return val_accuracy\n", "\n", @@ -1925,7 +1914,7 @@ " )\n", "\n", " for e in range(epochs):\n", - " print(f\"Epoch {e + 1}/{epochs}\")\n", + " print(f'Epoch {e + 1}/{epochs}')\n", " training_loss = 0\n", " tracker.increment()\n", " for batch_id, (X, A, y) in enumerate(dataloader_training):\n", @@ -1949,8 +1938,8 @@ " training_loss += loss.item()\n", " if batch_id % 100 == 0:\n", " print(\n", - " f\" Avg. training loss {training_loss / (batch_id + 1):>.4f}\",\n", - " end=\"\\r\",\n", + " f' Avg. 
training loss {training_loss / (batch_id + 1):>.4f}',\n", + " end='\\r',\n", " )\n", " if batch_id == 2000:\n", " break\n", @@ -1958,7 +1947,7 @@ " accuracy = tracker.compute()\n", " acc.append(accuracy)\n", " training_loss /= len(dataloader_training)\n", - " print(f\" Training Accuracy: {accuracy:>.3f} (Loss: {training_loss:>.3f})\")\n", + " print(f' Training Accuracy: {accuracy:>.3f} (Loss: {training_loss:>.3f})')\n", " validate_gnn(model, dataloader_val, loss_fn, tracker)\n", "\n", " return acc" @@ -1973,8 +1962,8 @@ "class AtomAromaticityData(Dataset):\n", " def __init__(self) -> None:\n", " # super().__init__()\n", - " self.X = np.concatenate(df_nist[\"features\"].values, dtype=\"float32\")\n", - " self.y = np.concatenate(df_nist[\"is_aromatic\"].values * 1, dtype=\"float32\")\n", + " self.X = np.concatenate(df_nist['features'].values, dtype='float32')\n", + " self.y = np.concatenate(df_nist['is_aromatic'].values * 1, dtype='float32')\n", "\n", " # ADD label to input features\n", " # self.X = [np.append(self.X[i], self.y[i]) for i in range(len(self))]\n", @@ -2061,17 +2050,17 @@ " def __init__(self) -> None:\n", "\n", " self.X = (\n", - " df_nist[\"features\"]\n", + " df_nist['features']\n", " .apply(lambda x: torch.tensor(x, dtype=torch.float32))\n", " .values\n", " ) # [:1000]\n", " self.A = (\n", - " df_nist[\"Atilde\"]\n", + " df_nist['Atilde']\n", " .apply(lambda x: torch.tensor(x, dtype=torch.float32))\n", " .values\n", " )\n", " self.y = (\n", - " df_nist[\"is_aromatic\"]\n", + " df_nist['is_aromatic']\n", " .apply(lambda x: torch.tensor(x, dtype=torch.float32))\n", " .values\n", " * 1\n", @@ -2580,10 +2569,10 @@ " acc, mean_loss = sklearn.metrics.accuracy_score(y_true, y_hat), np.mean(losses)\n", " if verbose:\n", " print(\n", - " \"Node/Edge level accuracy: %.03f; Mean loss: %.03f; Correct molecules: %.03f\"\n", + " 'Node/Edge level accuracy: %.03f; Mean loss: %.03f; Correct molecules: %.03f'\n", " % (acc, mean_loss, correct_mol / len(data))\n", " )\n", - " print(f\"Torch MetricTracker Accuracy: {tracker.compute():.3f}\")\n", + " print(f'Torch MetricTracker Accuracy: {tracker.compute():.3f}')\n", " return acc, mean_loss, correct_mol / len(data)\n", "\n", "\n", @@ -2602,7 +2591,7 @@ " DataLoader(validation_data, batch_size=1),\n", " )\n", " for epoch in range(1, 6):\n", - " print(\"Epoch %s\" % epoch)\n", + " print('Epoch %s' % epoch)\n", " training_loss = 0\n", " training_loss_torch = 0\n", " tracker.increment()\n", @@ -2638,8 +2627,8 @@ "\n", " if batch_id % 100 == 0:\n", " print(\n", - " f\" Avg. training loss {training_loss / (batch_id + 1):>.4f} (torch loss: {training_loss_torch / (batch_id + 1)}\",\n", - " end=\"\\r\",\n", + " f' Avg. 
training loss {training_loss / (batch_id + 1):>.4f} (torch loss: {training_loss_torch / (batch_id + 1)})',\n",
+    " end='\\r',\n",
 " )\n",
 "\n",
 " # On epoch end: Evaluation\n",
 " accuracy = tracker.compute()\n",
 " acc.append(accuracy)\n",
 " training_loss /= len(dataloader_training)\n",
 " print(\n",
-    " f\" Training Accuracy: {accuracy:>.3f} (Loss: {training_loss:>.3f})\",\n",
+    " f' Training Accuracy: {accuracy:>.3f} (Loss: {training_loss:>.3f})',\n",
 " flush=True,\n",
 " )\n",
 " # validate_gnn(gnn_model, dataloader_val, loss_fn, tracker)\n",
 "\n",
 " validate_tf_model(\n",
 " data=validation_data,\n",
 " model=model,\n",
-    " y_label=\"is_aromatic\",\n",
+    " y_label='is_aromatic',\n",
 " verbose=True,\n",
 " tracker=tracker,\n",
 " )\n",
@@ -2702,20 +2691,7 @@
 "execution_count": null,
 "metadata": {},
 "outputs": [],
-    "source": [
-    "#\n",
-    "#\n",
-    "#\n",
-    "#\n",
-    "#\n",
-    "#\n",
-    "#\n",
-    "#\n",
-    "#\n",
-    "#\n",
-    "#\n",
-    "#"
-    ]
+    "source": []
 },
 {
 "cell_type": "code",
@@ -2751,7 +2727,7 @@
 " acc, mean_loss = sklearn.metrics.accuracy_score(y_true, y_hat), np.mean(losses)\n",
 " if verbose:\n",
 " print(\n",
-    " \"Node/Edge level accuracy: %.03f; Mean loss: %.03f; Correct molecules: %.03f\"\n",
+    " 'Node/Edge level accuracy: %.03f; Mean loss: %.03f; Correct molecules: %.03f'\n",
 " % (acc, mean_loss, correct_mol / data.shape[0])\n",
 " )\n",
 "\n",
@@ -2785,9 +2761,9 @@
 " optimizer.apply_gradients(zip(gradients, variables))\n",
 "\n",
 " # Validate loss/acc\n",
-    " print(\"Epoch %s\" % epoch)\n",
+    " print('Epoch %s' % epoch)\n",
 " validate_tf_model_old(\n",
-    " data=data_val, model=gnn_model, y_label=\"is_aromatic\", verbose=True\n",
+    " data=data_val, model=gnn_model, y_label='is_aromatic', verbose=True\n",
 " )\n",
 "\n",
 " return\n",
@@ -3088,9 +3064,9 @@
 " edge_connected_node_features = df_nist.apply(\n",
 " lambda x: self.concatenate_node_features(x), axis=1\n",
 " )\n",
-    " self.X = np.concatenate(edge_connected_node_features.values, dtype=\"float32\")\n",
+    " self.X = np.concatenate(edge_connected_node_features.values, dtype='float32')\n",
 " self.y = np.concatenate(\n",
-    " df_nist[\"edges_is_aromatic\"].values * 1, dtype=\"float32\"\n",
+    " df_nist['edges_is_aromatic'].values * 1, dtype='float32'\n",
 " )\n",
 "\n",
 " # ADD label to input features\n",
 " # self.X = [np.append(self.X[i], self.y[i]) for i in range(len(self))]\n",
 "\n",
 " def concatenate_node_features(self, d):\n",
 " return np.concatenate(\n",
-    " [d[\"AL\"] @ d[\"features\"], d[\"AR\"] @ d[\"features\"]], axis=1\n",
+    " [d['AL'] @ d['features'], d['AR'] @ d['features']], axis=1\n",
 " )\n",
 "\n",
 " def __len__(self):\n",
@@ -3194,15 +3170,13 @@
 "metadata": {},
 "outputs": [],
 "source": [
-    "from modules.MOL.FragmentationTree import FragmentationTree\n",
-    "\n",
 "importlib.reload(modules.MOL.FragmentationTree)\n",
 "\n",
-    "FT = FragmentationTree(x[\"MOL\"])\n",
+    "FT = FragmentationTree(x['MOL'])\n",
 "# frag.build_fragmentation_tree_by_rotatable_bond_breaks()\n",
-    "FT.build_fragmentation_tree(x[\"MOL\"], x.edges_idx, depth=1)\n",
+    "FT.build_fragmentation_tree(x['MOL'], x.edges_idx, depth=1)\n",
 "\n",
-    "x_matches = FT.match_peak_list(df_nist.loc[EXAMPLE_ID][\"peaks\"][\"mz\"])\n",
+    "x_matches = FT.match_peak_list(df_nist.loc[EXAMPLE_ID]['peaks']['mz'])\n",
 "\n",
 "print(len(x_matches.keys()))\n",
 "for x in x_matches.keys():\n",
@@ -3216,35 +3190,35 @@
 "outputs": [],
 "source": [
 "fig, axs = plt.subplots(\n",
-    " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n",
+    " 1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n",
")\n", "\n", - "img = Chem.Draw.MolToImage(x[\"MOL\"], ax=axs[0])\n", + "img = Chem.Draw.MolToImage(x['MOL'], ax=axs[0])\n", "\n", "axs[0].grid(False)\n", "axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", ")\n", - "axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + Chem.MolToSmiles(x[\"MOL\"]))\n", + "axs[0].set_title(x['Name'] + ' structure:\\n' + Chem.MolToSmiles(x['MOL']))\n", "axs[0].imshow(img)\n", - "axs[0].axis(\"off\")\n", - "sv.plot_spectrum(title=x[\"Name\"] + \" MS/MS spectrum\", spectrum=x, ax=axs[1])\n", - "axs[1].text(100, 0.20, \"GeeksforGeeks\", style=\"italic\", fontsize=30, color=\"green\")\n", + "axs[0].axis('off')\n", + "sv.plot_spectrum(title=x['Name'] + ' MS/MS spectrum', spectrum=x, ax=axs[1])\n", + "axs[1].text(100, 0.20, 'GeeksforGeeks', style='italic', fontsize=30, color='green')\n", "\n", "\n", - "print(\"m/z\", x[\"peaks\"][\"mz\"])\n", - "print(\"Int\", x[\"peaks\"][\"intensity\"])\n", - "print(\"\\nCreate Fragmentation Tree (depth = 1)\\n\")\n", + "print('m/z', x['peaks']['mz'])\n", + "print('Int', x['peaks']['intensity'])\n", + "print('\\nCreate Fragmentation Tree (depth = 1)\\n')\n", "print(FT)\n", "# print(t.size(level=0), t.size(level=1))# t.size(level=2))\n", "\n", - "mols = [x[\"MOL\"], FT.get_fragment(3), FT.get_fragment(10), FT.get_fragment(7)]\n", + "mols = [x['MOL'], FT.get_fragment(3), FT.get_fragment(10), FT.get_fragment(7)]\n", "\n", "Chem.Draw.MolsToGridImage(\n", " mols,\n", " molsPerRow=4,\n", " useSVG=True,\n", - " legends=[f\" mol ({Chem.Descriptors.ExactMolWt(m)})\" for m in mols],\n", + " legends=[f' mol ({Chem.Descriptors.ExactMolWt(m)})' for m in mols],\n", ")" ] }, @@ -5322,11 +5296,11 @@ "#\n", "\n", "Chem.Draw.MolsToGridImage(\n", - " [df_nist.iloc[i][\"MOL\"] for i in range(50)],\n", + " [df_nist.iloc[i]['MOL'] for i in range(50)],\n", " molsPerRow=4,\n", " useSVG=True,\n", " legends=[str(i) for i in range(50)],\n", - ") #" + ")" ] }, { @@ -5491,7 +5465,7 @@ "source": [ "n = df_nist.iloc[40].Name\n", "print(n)\n", - "df_nist[df_nist.Name == n][[\"CE\", \"Num peaks\", \"peaks\"]]" + "df_nist[df_nist.Name == n][['CE', 'Num peaks', 'peaks']]" ] }, { @@ -5508,43 +5482,43 @@ "# print(x)\n", "\n", "FT = FragmentationTree()\n", - "FT.build_fragmentation_tree_by_single_edge_breaks(x[\"MOL\"], x.edges_idx, depth=1)\n", + "FT.build_fragmentation_tree_by_single_edge_breaks(x['MOL'], x.edges_idx, depth=1)\n", "\n", "# print(Chem.Descriptors.ExactMolWt(x[\"MOL\"]))\n", "# print(Chem.Descriptors.ExactMolWt(t.get_node(6).data))\n", "\n", "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", ")\n", "\n", - "img = Chem.Draw.MolToImage(x[\"MOL\"], ax=axs[0])\n", + "img = Chem.Draw.MolToImage(x['MOL'], ax=axs[0])\n", "\n", "axs[0].grid(False)\n", "axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", ")\n", - "axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + Chem.MolToSmiles(x[\"MOL\"]))\n", + "axs[0].set_title(x['Name'] + ' structure:\\n' + Chem.MolToSmiles(x['MOL']))\n", "axs[0].imshow(img)\n", - "axs[0].axis(\"off\")\n", - "sv.plot_spectrum(title=x[\"Name\"] + \" MS/MS spectrum\", spectrum=x, ax=axs[1])\n", - "axs[1].text(100, 
0.20, \"GeeksforGeeks\", style=\"italic\", fontsize=30, color=\"green\")\n", + "axs[0].axis('off')\n", + "sv.plot_spectrum(title=x['Name'] + ' MS/MS spectrum', spectrum=x, ax=axs[1])\n", + "axs[1].text(100, 0.20, 'GeeksforGeeks', style='italic', fontsize=30, color='green')\n", "\n", "\n", - "print(\"m/z\", x[\"peaks\"][\"mz\"])\n", - "print(\"Int\", x[\"peaks\"][\"intensity\"])\n", - "print(\"\\nCreate Fragmentation Tree (depth = 1)\\n\")\n", + "print('m/z', x['peaks']['mz'])\n", + "print('Int', x['peaks']['intensity'])\n", + "print('\\nCreate Fragmentation Tree (depth = 1)\\n')\n", "t.show(idhidden=False)\n", "# print(t.size(level=0), t.size(level=1))# t.size(level=2))\n", "\n", "Chem.Draw.MolsToGridImage(\n", - " [x[\"MOL\"], t.get_node(1).data, t.get_node(4).data, t.get_node(5).data],\n", + " [x['MOL'], t.get_node(1).data, t.get_node(4).data, t.get_node(5).data],\n", " molsPerRow=4,\n", " useSVG=True,\n", " legends=[\n", - " \"intact\",\n", - " f\"frag ({t.get_node(1).tag:.03f})\",\n", - " f\"frag ({t.get_node(4).tag:.03f})\",\n", - " f\"frag ({t.get_node(5).tag:.03f})\",\n", + " 'intact',\n", + " f'frag ({t.get_node(1).tag:.03f})',\n", + " f'frag ({t.get_node(4).tag:.03f})',\n", + " f'frag ({t.get_node(5).tag:.03f})',\n", " ],\n", ")" ] @@ -5583,10 +5557,10 @@ "for off in offsets:\n", " print(\n", " off,\n", - " \":\",\n", - " np.mean(D[str(off)][\"peaks\"]),\n", - " np.mean(D[str(off)][\"unique\"]),\n", - " np.mean(D[str(off)][\"percentage\"]),\n", + " ':',\n", + " np.mean(D[str(off)]['peaks']),\n", + " np.mean(D[str(off)]['unique']),\n", + " np.mean(D[str(off)]['percentage']),\n", " )" ] }, @@ -5642,17 +5616,17 @@ "for off in offsets:\n", " fig, axs = plt.subplots(1, 3, figsize=(12.8, 6.4), sharey=False)\n", "\n", - " fig.suptitle(f\"Identified peaks with fragment offset: {str(off)}\")\n", + " fig.suptitle(f'Identified peaks with fragment offset: {str(off)}')\n", " # plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n", - " sns.boxplot(ax=axs[0], y=D[str(off)][\"peaks\"], color=color_palette[0])\n", + " sns.boxplot(ax=axs[0], y=D[str(off)]['peaks'], color=color_palette[0])\n", " axs[0].set_ylim(-0.5, 10)\n", - " axs[0].set_ylabel(\"peaks identified\")\n", - " sns.boxplot(ax=axs[1], y=D[str(off)][\"unique\"], color=color_palette[1])\n", + " axs[0].set_ylabel('peaks identified')\n", + " sns.boxplot(ax=axs[1], y=D[str(off)]['unique'], color=color_palette[1])\n", " axs[1].set_ylim(-0.5, 10)\n", - " axs[1].set_ylabel(\"peaks uniquely identified\")\n", - " sns.boxplot(ax=axs[2], y=D[str(off)][\"percentage\"], color=color_palette[2])\n", + " axs[1].set_ylabel('peaks uniquely identified')\n", + " sns.boxplot(ax=axs[2], y=D[str(off)]['percentage'], color=color_palette[2])\n", " axs[2].set_ylim(-0.05, 1.0)\n", - " axs[2].set_ylabel(\"peak percentags (by number)\")\n", + " axs[2].set_ylabel('peak percentags (by number)')\n", " plt.show()" ] }, @@ -5664,9 +5638,9 @@ "source": [ "print(zigzagerrorhack)\n", "frame = PandasTools.LoadSDF(\n", - " library_directory + library_name + \".SDF\",\n", - " smilesName=\"SMILES\",\n", - " molColName=\"Molecule\",\n", + " library_directory + library_name + '.SDF',\n", + " smilesName='SMILES',\n", + " molColName='Molecule',\n", " includeFingerprints=True,\n", ")\n", "\n", @@ -5680,14 +5654,14 @@ "outputs": [], "source": [ "# structure library\n", - "structure_supplier = Chem.SDMolSupplier(library_directory + library_name + \".SDF\")\n", + "structure_supplier = Chem.SDMolSupplier(library_directory + library_name + '.SDF')\n", 
"print(structure_supplier)\n", "nist_mols = [structure_supplier[i] for i in range(0, 100)] # TODO use all\n", "df_nist = pd.DataFrame()\n", - "df_nist[\"mol\"] = nist_mols\n", + "df_nist['mol'] = nist_mols\n", "df_nist.dropna(inplace=True)\n", "\n", - "df_nist[\"smiles\"] = df_nist[\"mol\"].apply(lambda x: Chem.MolToSmiles(x))\n", + "df_nist['smiles'] = df_nist['mol'].apply(lambda x: Chem.MolToSmiles(x))\n", "\n", "df_nist.head()" ] @@ -5739,7 +5713,7 @@ "source": [ "import networkx as nx\n", "\n", - "color_map = {\"C\": \"gray\", \"O\": \"red\", \"N\": \"blue\"}\n", + "color_map = {'C': 'gray', 'O': 'red', 'N': 'blue'}\n", "\n", "\n", "def mol_to_nx(mol):\n", @@ -5749,7 +5723,7 @@ " color = (\n", " color_map[atom.GetSymbol()]\n", " if atom.GetSymbol() in color_map.keys()\n", - " else \"black\"\n", + " else 'black'\n", " )\n", " G.add_node(\n", " atom.GetIdx(),\n", @@ -5777,17 +5751,17 @@ " nx.draw(\n", " G,\n", " pos=pos,\n", - " labels=nx.get_node_attributes(G, \"atom_symbol\"),\n", + " labels=nx.get_node_attributes(G, 'atom_symbol'),\n", " with_labels=True,\n", - " node_color=list(nx.get_node_attributes(G, \"color\").values()),\n", + " node_color=list(nx.get_node_attributes(G, 'color').values()),\n", " node_size=800,\n", " )\n", " if edge_labels:\n", " nx.draw_networkx_edge_labels(\n", " G,\n", " pos,\n", - " edge_labels=dict([((n1, n2), f\"({n1}, {n2})\") for n1, n2 in G.edges]),\n", - " font_color=\"red\",\n", + " edge_labels=dict([((n1, n2), f'({n1}, {n2})') for n1, n2 in G.edges]),\n", + " font_color='red',\n", " )\n", " plt.show()\n", "\n", @@ -5796,25 +5770,25 @@ " def __init__(self):\n", " self.encoded_dim = 0\n", " self.sets = {\n", - " \"symbol\": {\n", - " \"B\",\n", - " \"Br\",\n", - " \"C\",\n", - " \"Ca\",\n", - " \"Cl\",\n", - " \"F\",\n", - " \"H\",\n", - " \"I\",\n", - " \"N\",\n", - " \"Na\",\n", - " \"O\",\n", - " \"P\",\n", - " \"S\",\n", + " 'symbol': {\n", + " 'B',\n", + " 'Br',\n", + " 'C',\n", + " 'Ca',\n", + " 'Cl',\n", + " 'F',\n", + " 'H',\n", + " 'I',\n", + " 'N',\n", + " 'Na',\n", + " 'O',\n", + " 'P',\n", + " 'S',\n", " },\n", - " \"num_hydrogen\": {0, 1, 2, 3, 4, 5, 6, 7, 8},\n", + " 'num_hydrogen': {0, 1, 2, 3, 4, 5, 6, 7, 8},\n", " }\n", " self.reduced_features = [\n", - " \"symbol\"\n", + " 'symbol'\n", " ] # Where unknown variables (not in the set) might occur, these will get a combined bit in the encoded vector\n", " self.one_hot_mapper = {}\n", " for feature in self.sets.keys():\n", @@ -5833,17 +5807,17 @@ " feature_matrix = np.zeros(shape=(G.number_of_nodes(), self.encoded_dim))\n", "\n", " for i in range(G.number_of_nodes()):\n", - " atom = G.nodes()[i][\"atom\"]\n", + " atom = G.nodes()[i]['atom']\n", "\n", - " if not atom.GetSymbol() in self.sets[\"symbol\"]:\n", + " if atom.GetSymbol() not in self.sets['symbol']:\n", " feature_matrix[i][\n", - " self.one_hot_mapper[\"symbol\"][list(self.sets[\"symbol\"])[-1]] + 1\n", + " self.one_hot_mapper['symbol'][list(self.sets['symbol'])[-1]] + 1\n", " ] = 1.0\n", " else:\n", - " feature_matrix[i][self.one_hot_mapper[\"symbol\"][atom.GetSymbol()]] = 1.0\n", + " feature_matrix[i][self.one_hot_mapper['symbol'][atom.GetSymbol()]] = 1.0\n", "\n", " feature_matrix[i][\n", - " self.one_hot_mapper[\"num_hydrogen\"][atom.GetTotalNumHs()]\n", + " self.one_hot_mapper['num_hydrogen'][atom.GetTotalNumHs()]\n", " ] = 1.0\n", "\n", " return feature_matrix\n", @@ -5877,39 +5851,39 @@ "\n", "\n", "def add_dataframe_features(df):\n", - " df[\"graph\"] = df[\"mol\"].apply(mol_to_nx)\n", - " df[\"features\"] = 
df[\"graph\"].apply(lambda x: node_encoder.encode(x))\n", - " df[\"Xsymbol\"] = df[\"graph\"].apply(\n", - " lambda x: [x.nodes[atom][\"atom_symbol\"] for atom in x.nodes()]\n", + " df['graph'] = df['mol'].apply(mol_to_nx)\n", + " df['features'] = df['graph'].apply(lambda x: node_encoder.encode(x))\n", + " df['Xsymbol'] = df['graph'].apply(\n", + " lambda x: [x.nodes[atom]['atom_symbol'] for atom in x.nodes()]\n", " )\n", - " df[\"Xi\"] = df[\"graph\"].apply(\n", + " df['Xi'] = df['graph'].apply(\n", " lambda x: [\n", - " min(x.nodes[atom][\"atomic_num\"], num_elems - 1) for atom in x.nodes()\n", + " min(x.nodes[atom]['atomic_num'], num_elems - 1) for atom in x.nodes()\n", " ]\n", " )\n", - " df[\"X\"] = df[\"Xi\"].apply(lambda x: to_categorical(x, num_classes=num_elems))\n", - " df[\"A\"] = df[\"graph\"].apply(nx.convert_matrix.to_numpy_matrix)\n", - " df[\"Atilde\"] = df[\"A\"].apply(lambda x: x + np.eye(N=x.shape[0]))\n", - " df[\"Id\"] = df[\"A\"].apply(lambda x: np.eye(N=x.shape[0]))\n", - " df[\"deg\"] = df[\"A\"].apply(\n", + " df['X'] = df['Xi'].apply(lambda x: to_categorical(x, num_classes=num_elems))\n", + " df['A'] = df['graph'].apply(nx.convert_matrix.to_numpy_matrix)\n", + " df['Atilde'] = df['A'].apply(lambda x: x + np.eye(N=x.shape[0]))\n", + " df['Id'] = df['A'].apply(lambda x: np.eye(N=x.shape[0]))\n", + " df['deg'] = df['A'].apply(\n", " lambda x: tf.transpose(\n", " [tf.clip_by_value(tf.reduce_sum(x, axis=-1), 0.0001, 1000.0)]\n", " )\n", " )\n", - " df[\"isAromatic\"] = df[\"graph\"].apply(\n", - " lambda x: np.array([[x.nodes[atom][\"is_aromatic\"] for atom in x.nodes()]]).T\n", + " df['isAromatic'] = df['graph'].apply(\n", + " lambda x: np.array([[x.nodes[atom]['is_aromatic'] for atom in x.nodes()]]).T\n", " )\n", "\n", " # Extras\n", - " df[\"isN\"] = df[\"graph\"].apply(\n", + " df['isN'] = df['graph'].apply(\n", " lambda x: np.array(\n", - " [[int(x.nodes[atom][\"atom_symbol\"] == \"N\") for atom in x.nodes()]]\n", + " [[int(x.nodes[atom]['atom_symbol'] == 'N') for atom in x.nodes()]]\n", " )\n", " )\n", - " df[\"isN_in_radius1\"] = [df.loc[i, \"Atilde\"] * df.loc[i, \"isN\"].T for i in df.index]\n", - " df[\"isN_in_radius1\"] = df[\"isN_in_radius1\"].apply(lambda x: x.clip(0, 1))\n", - " df[\"isN_neighboring\"] = [df.loc[i, \"A\"] * df.loc[i, \"isN\"].T for i in df.index]\n", - " df[\"isN_neighboring\"] = df[\"isN_neighboring\"].apply(lambda x: x.clip(0, 1))\n", + " df['isN_in_radius1'] = [df.loc[i, 'Atilde'] * df.loc[i, 'isN'].T for i in df.index]\n", + " df['isN_in_radius1'] = df['isN_in_radius1'].apply(lambda x: x.clip(0, 1))\n", + " df['isN_neighboring'] = [df.loc[i, 'A'] * df.loc[i, 'isN'].T for i in df.index]\n", + " df['isN_neighboring'] = df['isN_neighboring'].apply(lambda x: x.clip(0, 1))\n", " return df\n", "\n", "\n", @@ -6061,7 +6035,7 @@ " acc, mean_loss = sklearn.metrics.accuracy_score(y_true, y_hat), np.mean(losses)\n", " if verbose:\n", " print(\n", - " \"Node/Edge level accuracy: %.03f; Mean loss: %.03f; Correct molecules: %.03f\"\n", + " 'Node/Edge level accuracy: %.03f; Mean loss: %.03f; Correct molecules: %.03f'\n", " % (acc, mean_loss, correct_mol / data.shape[0])\n", " )\n", "\n", @@ -6087,9 +6061,9 @@ " optimizer.apply_gradients(zip(gradients, variables))\n", "\n", " # Validate loss/acc\n", - " print(\"Epoch %s\" % epoch)\n", + " print('Epoch %s' % epoch)\n", " validate_model(\n", - " data=data_val, model=gnn_model, y_label=\"isAromatic\", verbose=True\n", + " data=data_val, model=gnn_model, y_label='isAromatic', verbose=True\n", " )\n", 
"\n", " return" @@ -6119,8 +6093,8 @@ }, "outputs": [], "source": [ - "print(\"Test GNN\")\n", - "acc, loss, correct = validate_model(df_test, gnn_model, \"isAromatic\", verbose=True)\n", + "print('Test GNN')\n", + "acc, loss, correct = validate_model(df_test, gnn_model, 'isAromatic', verbose=True)\n", "# print()" ] }, @@ -6162,7 +6136,7 @@ "yhat = gnn_model(d.features, tf.cast(d.A / d.deg, dtype=tf.float32))\n", "ytrue = tf.cast(d.isAromatic, dtype=tf.float32)\n", "\n", - "print(\" y y_hat \")\n", + "print(' y y_hat ')\n", "print(np.round(np.array(tf.concat([ytrue, tf.nn.sigmoid(yhat)], axis=1)), decimals=2))" ] }, @@ -6205,7 +6179,7 @@ " AL[row + i, j] = 1.0\n", " AR[row + i, edges_to[i]] = 1.0\n", " y_edge.append(\n", - " G[j][edges_to[i]][\"bond_type\"].name == \"AROMATIC\"\n", + " G[j][edges_to[i]]['bond_type'].name == 'AROMATIC'\n", " ) # Add y condition here\n", " edge_idx.append((j, edges_to[i]))\n", " row += row_degree\n", @@ -6214,13 +6188,13 @@ "\n", "\n", "def add_dataframe_edge_features(df):\n", - " df[\"AL\"] = df.apply(\n", - " lambda x: compute_helper_matrices(x[\"A\"], x[\"deg\"], x[\"graph\"]), axis=1\n", + " df['AL'] = df.apply(\n", + " lambda x: compute_helper_matrices(x['A'], x['deg'], x['graph']), axis=1\n", " )\n", - " df[\"AR\"] = df[\"AL\"].apply(lambda x: x[1])\n", - " df[\"edges_is_aromatic\"] = df[\"AL\"].apply(lambda x: np.array([x[2]]).T)\n", - " df[\"edges_idx\"] = df[\"AL\"].apply(lambda x: x[3])\n", - " df[\"AL\"] = df[\"AL\"].apply(lambda x: x[0])\n", + " df['AR'] = df['AL'].apply(lambda x: x[1])\n", + " df['edges_is_aromatic'] = df['AL'].apply(lambda x: np.array([x[2]]).T)\n", + " df['edges_idx'] = df['AL'].apply(lambda x: x[3])\n", + " df['AL'] = df['AL'].apply(lambda x: x[0])\n", " return df\n", "\n", "\n", @@ -6285,14 +6259,14 @@ " optimizer.apply_gradients(zip(gradients, variables))\n", "\n", " # Validate loss/acc\n", - " print(\"Epoch %s\" % epoch)\n", + " print('Epoch %s' % epoch)\n", " validate_model(\n", " data=data_val,\n", " model=edge_prediction_model,\n", - " y_label=\"edges_is_aromatic\",\n", + " y_label='edges_is_aromatic',\n", " verbose=True,\n", - " AL=\"AL\",\n", - " AR=\"AR\",\n", + " AL='AL',\n", + " AR='AR',\n", " )\n", "\n", " return" @@ -6324,14 +6298,14 @@ }, "outputs": [], "source": [ - "print(\"Test Edge GNN\")\n", + "print('Test Edge GNN')\n", "acc, loss, correct = validate_model(\n", " data=df_test,\n", " model=edge_prediction_model,\n", - " y_label=\"edges_is_aromatic\",\n", + " y_label='edges_is_aromatic',\n", " verbose=True,\n", - " AL=\"AL\",\n", - " AR=\"AR\",\n", + " AL='AL',\n", + " AR='AR',\n", ")" ] }, @@ -6372,7 +6346,7 @@ "for i in range(tf.cast(d.A, dtype=tf.float32).shape[0]):\n", " for j in range(tf.cast(d.A, dtype=tf.float32).shape[1]):\n", " if d.A[i, j] >= 1:\n", - " print(G[i][j][\"bond\"].GetBondType())" + " print(G[i][j]['bond'].GetBondType())" ] }, { @@ -6417,7 +6391,7 @@ }, "outputs": [], "source": [ - "nicotine_mol = Chem.MolFromSmiles(\"NCCCCC(N)CC(O)=O\")\n", + "nicotine_mol = Chem.MolFromSmiles('NCCCCC(N)CC(O)=O')\n", "print(Descriptors.ExactMolWt(nicotine_mol) + 1)\n", "G = mol_to_nx(nicotine_mol)\n", "draw_graph(G, edge_labels=True)\n", @@ -6507,9 +6481,9 @@ "outputs": [], "source": [ "# spectral library\n", - "nist_msp = mspReader.read(f\"{home}/data/metabolites/MassBank/MassBank_NIST.msp\")\n", + "nist_msp = mspReader.read(f'{home}/data/metabolites/MassBank/MassBank_NIST.msp')\n", "df = pd.DataFrame(nist_msp)\n", - "df[\"mol\"] = df[\"SMILES\"].apply(Chem.MolFromSmiles)\n", + "df['mol'] = 
df['SMILES'].apply(Chem.MolFromSmiles)\n", "df.dropna(inplace=True)\n", "print(df.shape)" ] @@ -6551,7 +6525,7 @@ "print(df.Precursor_type.unique())\n", "print(df.Instrument_type.unique())\n", "print(df.Collision_energy.unique())\n", - "print(sum(df.Collision_energy == \"30(NCE)\"))\n", + "print(sum(df.Collision_energy == '30(NCE)'))\n", "sns.histplot(df.Collision_energy)\n", "plt.show()" ] @@ -6577,12 +6551,12 @@ }, "outputs": [], "source": [ - "df = df[df.Ion_mode == \"POSITIVE\"]\n", - "df = df[df.Precursor_type == \"[M+H]+\"]\n", - "df = df[df.Spectrum_type == \"MS2\"]\n", - "df = df[df.Instrument_type == \"LC-ESI-ITFT\"]\n", - "df[\"Num Peaks\"] = df[\"Num Peaks\"].astype(int)\n", - "df = df[df[\"Num Peaks\"] > 1]\n", + "df = df[df.Ion_mode == 'POSITIVE']\n", + "df = df[df.Precursor_type == '[M+H]+']\n", + "df = df[df.Spectrum_type == 'MS2']\n", + "df = df[df.Instrument_type == 'LC-ESI-ITFT']\n", + "df['Num Peaks'] = df['Num Peaks'].astype(int)\n", + "df = df[df['Num Peaks'] > 1]\n", "print(df.shape)\n", "# df_nist = df_nist.iloc[:1000] # Reduce to 1000\n", "\n", @@ -6608,8 +6582,8 @@ " [\n", " abs(\n", " (\n", - " float(df.loc[x, \"ExactMass\"])\n", - " - Descriptors.ExactMolWt(df.loc[x, \"mol\"])\n", + " float(df.loc[x, 'ExactMass'])\n", + " - Descriptors.ExactMolWt(df.loc[x, 'mol'])\n", " )\n", " > 0.2\n", " )\n", @@ -6622,8 +6596,8 @@ " [\n", " abs(\n", " (\n", - " float(df.loc[x, \"PrecursorMZ\"])\n", - " - (Descriptors.ExactMolWt(df.loc[x, \"mol\"]) + PROTON_MZ)\n", + " float(df.loc[x, 'PrecursorMZ'])\n", + " - (Descriptors.ExactMolWt(df.loc[x, 'mol']) + PROTON_MZ)\n", " )\n", " > 0.2\n", " )\n", @@ -6636,12 +6610,12 @@ "# Adjust types and assess specific features\n", "#\n", "\n", - "df[\"PrecursorMZ\"] = df[\"PrecursorMZ\"].astype(\"float32\")\n", - "df[\"theoretical_PrecursorMZ\"] = df[\"mol\"].apply(\n", + "df['PrecursorMZ'] = df['PrecursorMZ'].astype('float32')\n", + "df['theoretical_PrecursorMZ'] = df['mol'].apply(\n", " lambda x: Descriptors.ExactMolWt(x) + PROTON_MZ\n", ")\n", - "df[\"PrecursorMZ_difference\"] = df[\"PrecursorMZ\"] - df[\"theoretical_PrecursorMZ\"]\n", - "df[\"absPrecursorMZ_difference\"] = df[\"PrecursorMZ_difference\"].apply(abs)\n", + "df['PrecursorMZ_difference'] = df['PrecursorMZ'] - df['theoretical_PrecursorMZ']\n", + "df['absPrecursorMZ_difference'] = df['PrecursorMZ_difference'].apply(abs)\n", "\n", "df = df[df.absPrecursorMZ_difference < MZ_TOLERANCE]\n", "\n", @@ -6674,9 +6648,10 @@ }, "outputs": [], "source": [ - "from treelib import Node, Tree\n", "from copy import copy\n", "\n", + "from treelib import Node, Tree\n", + "\n", "\n", "def create_fragments(mol, i, j):\n", " try:\n", @@ -6772,7 +6747,7 @@ " [d.mol, t.get_node(20).data, t.get_node(14).data, t.get_node(263).data],\n", " molsPerRow=4,\n", " useSVG=True,\n", - " legends=[\"intact\", \"fragment 1\", \"fragment 2\", \"fragment 3\"],\n", + " legends=['intact', 'fragment 1', 'fragment 2', 'fragment 3'],\n", ")" ] }, @@ -6815,7 +6790,7 @@ " if tree.level(node.identifier) > depth:\n", " continue\n", " mz = node.tag + PROTON_MZ\n", - " peak_idx = match_peak_to_list(mz, peaks[\"mz\"])\n", + " peak_idx = match_peak_to_list(mz, peaks['mz'])\n", " if not np.isnan(peak_idx):\n", " fragments.append(\n", " (\n", @@ -6826,19 +6801,19 @@ " else tree.parent(node.identifier).identifier,\n", " )\n", " )\n", - " if not peaks[\"mz\"][peak_idx] in peak_list:\n", - " peak_list.append(peaks[\"mz\"][peak_idx])\n", - " intensities.append(peaks[\"intensity\"][peak_idx])\n", + " if peaks['mz'][peak_idx] not in 
peak_list:\n", + " peak_list.append(peaks['mz'][peak_idx])\n", + " intensities.append(peaks['intensity'][peak_idx])\n", "\n", " return fragments, peak_list, intensities\n", "\n", "\n", - "print(\"Spectrum peak list: %s\" % d.peaks[\"mz\"])\n", + "print('Spectrum peak list: %s' % d.peaks['mz'])\n", "DEPTH = 2\n", "frags, p, i = get_matching_peaks_until_depth(tree=t, peaks=d.peaks, depth=DEPTH)\n", - "print(\"Identified [M+H]+ fragments at treedepth %s: %s\" % (DEPTH, p))\n", - "print(\"Fragment intensities %s\" % i)\n", - "print(\"Fragment intensities covered %.02f\" % (sum(i) / sum(d.peaks[\"intensity\"])))\n", + "print('Identified [M+H]+ fragments at treedepth %s: %s' % (DEPTH, p))\n", + "print('Fragment intensities %s' % i)\n", + "print('Fragment intensities covered %.02f' % (sum(i) / sum(d.peaks['intensity'])))\n", "# collections.Counter([x[2] for x in frags])" ] }, @@ -6858,7 +6833,7 @@ " mols,\n", " molsPerRow=4,\n", " useSVG=True,\n", - " legends=[\"mz: %.2f\" % (Chem.Descriptors.ExactMolWt(x) + PROTON_MZ) for x in mols],\n", + " legends=['mz: %.2f' % (Chem.Descriptors.ExactMolWt(x) + PROTON_MZ) for x in mols],\n", ")" ] }, @@ -6873,13 +6848,13 @@ "outputs": [], "source": [ "# 2nd order fragmentation\n", - "print(\"All possible fragments matching: %s\" % frags)\n", + "print('All possible fragments matching: %s' % frags)\n", "for x in p:\n", " parents = []\n", " for f in frags:\n", " if do_peaks_match(x, f[0]):\n", " parents.append(f[2])\n", - " print(\"%s: %s\" % (x, parents))\n", + " print('%s: %s' % (x, parents))\n", "parents = [t.get_node(170).data, t.get_node(263).data, t.get_node(227).data]\n", "mols = [t.get_node(171).data, t.get_node(283).data, t.get_node(232).data]\n", "\n", @@ -6888,7 +6863,7 @@ " m,\n", " molsPerRow=3,\n", " useSVG=True,\n", - " legends=[\"mz: %.2f\" % (Chem.Descriptors.ExactMolWt(x) + PROTON_MZ) for x in m],\n", + " legends=['mz: %.2f' % (Chem.Descriptors.ExactMolWt(x) + PROTON_MZ) for x in m],\n", ")" ] }, @@ -6903,7 +6878,7 @@ "outputs": [], "source": [ "# s = mspReader.get_spectrum_by_name()\n", - "sv.plot_spectrum(d, title=\"Phacidin MS/MS\")" + "sv.plot_spectrum(d, title='Phacidin MS/MS')" ] }, { @@ -6929,17 +6904,17 @@ }, "outputs": [], "source": [ - "test = \"CCC=CC\"\n", + "test = 'CCC=CC'\n", "testmol = Chem.MolFromSmiles(test)\n", "\n", - "test2 = \"CC=CCC\"\n", + "test2 = 'CC=CCC'\n", "testmol2 = Chem.MolFromSmiles(test2)\n", "\n", - "test3 = \"C=CCCC\"\n", + "test3 = 'C=CCCC'\n", "testmol3 = Chem.MolFromSmiles(test3)\n", "\n", "\n", - "print(\"T/F check:\", equalMols(testmol, testmol2), equalMols(testmol2, testmol3))" + "print('T/F check:', equalMols(testmol, testmol2), equalMols(testmol2, testmol3))" ] }, { @@ -6952,7 +6927,7 @@ }, "outputs": [], "source": [ - "s = \"CCCCCCCCC(=O)C1=C(C=C(C(=C1O)C=O)O)O\"\n", + "s = 'CCCCCCCCC(=O)C1=C(C=C(C(=C1O)C=O)O)O'\n", "m = Chem.MolFromSmiles(s)\n", "\n", "m" @@ -6979,7 +6954,7 @@ }, "outputs": [], "source": [ - "df[df.Name == \"Ibuprofen\"]\n", + "df[df.Name == 'Ibuprofen']\n", "d = df.loc[17152]\n", "d.peaks" ] @@ -7012,11 +6987,11 @@ "source": [ "print(zigzagerrorhack)\n", "\n", - "df[\"fragmentation_tree\"] = None\n", - "df[\"identified_peaks\"] = None\n", - "df[\"identified_peaks_intensities\"] = None\n", - "df[\"intensity_covered\"] = None\n", - "df[\"num_identified_peaks\"] = None\n", + "df['fragmentation_tree'] = None\n", + "df['identified_peaks'] = None\n", + "df['identified_peaks_intensities'] = None\n", + "df['intensity_covered'] = None\n", + "df['num_identified_peaks'] = None\n", "\n", "\n", 
"DEPTH = 2\n", @@ -7024,21 +6999,21 @@ "c = 0\n", "for x in df.index:\n", " if c % 20 == 0:\n", - " print(\"%.02f%%\" % (100 * c / float(df.shape[0])), end=\"\\r\")\n", + " print('%.02f%%' % (100 * c / float(df.shape[0])), end='\\r')\n", " d = df.loc[x]\n", " t = build_fragmentation_tree(d.mol, d.edges_idx, depth=DEPTH)\n", " frags, p, i = get_matching_peaks_until_depth(tree=t, peaks=d.peaks, depth=DEPTH)\n", - " intensity_covered = sum(i) / sum(d.peaks[\"intensity\"])\n", + " intensity_covered = sum(i) / sum(d.peaks['intensity'])\n", "\n", " # FILTER: at least 1 peak other than the precursor, at least 0.5 of the total intensity covered\n", " if len(p) > 1 and intensity_covered > 0.5:\n", - " df.at[x, \"fragmentation_tree\"] = t\n", + " df.at[x, 'fragmentation_tree'] = t\n", " # df_nist.loc[x,'fragmentation_tree'] = t\n", " # df_nist.loc[x]['fragmentation_tree'] = t\n", - " df.at[x, \"identified_peaks\"] = p\n", - " df.at[x, \"identified_peaks_intensities\"] = i\n", - " df.at[x, \"intensity_covered\"] = intensity_covered\n", - " df.at[x, \"num_identified_peaks\"] = len(p)\n", + " df.at[x, 'identified_peaks'] = p\n", + " df.at[x, 'identified_peaks_intensities'] = i\n", + " df.at[x, 'intensity_covered'] = intensity_covered\n", + " df.at[x, 'num_identified_peaks'] = len(p)\n", " else:\n", " to_be_removed.append(x)\n", " c += 1" @@ -7057,9 +7032,9 @@ "print(zigzagerrorhack)\n", "\n", "# df_nist.to_csv('./nist_reduced.csv')\n", - "df.to_csv(f\"{home}/Desktop/nist_reduced.csv\")\n", + "df.to_csv(f'{home}/Desktop/nist_reduced.csv')\n", "\n", - "print(\"yes\")\n", + "print('yes')\n", "print(zigzagerrorhack)" ] }, @@ -7086,7 +7061,7 @@ "source": [ "import ast\n", "\n", - "df = pd.read_csv(\"./nist_reduced.csv\")\n", + "df = pd.read_csv('./nist_reduced.csv')\n", "\n", "df.head()" ] @@ -7101,10 +7076,10 @@ }, "outputs": [], "source": [ - "df = df[~df[\"intensity_covered\"].apply(np.isnan)]\n", - "df[\"peaks\"] = df[\"peaks\"].apply(ast.literal_eval)\n", - "df[\"identified_peaks\"] = df[\"identified_peaks\"].apply(ast.literal_eval)\n", - "df[\"identified_peaks_intensities\"] = df[\"identified_peaks_intensities\"].apply(\n", + "df = df[~df['intensity_covered'].apply(np.isnan)]\n", + "df['peaks'] = df['peaks'].apply(ast.literal_eval)\n", + "df['identified_peaks'] = df['identified_peaks'].apply(ast.literal_eval)\n", + "df['identified_peaks_intensities'] = df['identified_peaks_intensities'].apply(\n", " ast.literal_eval\n", ")\n", "print(df.shape)" @@ -7127,18 +7102,18 @@ " return -1\n", "\n", "\n", - "df[\"precursor_peak_idx\"] = df.apply(\n", - " lambda x: find_precursor_peak_idx(x[\"PrecursorMZ\"], x[\"peaks\"][\"mz\"]), axis=1\n", + "df['precursor_peak_idx'] = df.apply(\n", + " lambda x: find_precursor_peak_idx(x['PrecursorMZ'], x['peaks']['mz']), axis=1\n", ")\n", - "df[\"intensity_covered_without_precursor\"] = df.apply(\n", + "df['intensity_covered_without_precursor'] = df.apply(\n", " lambda x: (\n", " (\n", - " sum(x[\"identified_peaks_intensities\"])\n", - " - x[\"peaks\"][\"intensity\"][x[\"precursor_peak_idx\"]]\n", + " sum(x['identified_peaks_intensities'])\n", + " - x['peaks']['intensity'][x['precursor_peak_idx']]\n", " )\n", " / (\n", - " sum(x[\"peaks\"][\"intensity\"])\n", - " - x[\"peaks\"][\"intensity\"][x[\"precursor_peak_idx\"]]\n", + " sum(x['peaks']['intensity'])\n", + " - x['peaks']['intensity'][x['precursor_peak_idx']]\n", " )\n", " ),\n", " axis=1,\n", @@ -7147,11 +7122,11 @@ "# print((x['identified_peaks_intensities'][0]))\n", "print(\n", " (\n", - " 
sum(x[\"identified_peaks_intensities\"])\n", - " - x[\"peaks\"][\"intensity\"][x[\"precursor_peak_idx\"]]\n", + " sum(x['identified_peaks_intensities'])\n", + " - x['peaks']['intensity'][x['precursor_peak_idx']]\n", " )\n", - " / (sum(x[\"peaks\"][\"intensity\"]) - x[\"peaks\"][\"intensity\"][x[\"precursor_peak_idx\"]])\n", - ") #\n", + " / (sum(x['peaks']['intensity']) - x['peaks']['intensity'][x['precursor_peak_idx']])\n", + ")\n", "print(x.peaks, x.identified_peaks, x.PrecursorMZ, x.intensity_covered_without_precursor)" ] }, @@ -7165,30 +7140,30 @@ }, "outputs": [], "source": [ - "ax = sns.boxplot(x=df.num_identified_peaks, color=sns.color_palette(\"Paired\")[0])\n", + "ax = sns.boxplot(x=df.num_identified_peaks, color=sns.color_palette('Paired')[0])\n", "ax.set_xlim([1, 15])\n", "plt.show()\n", "\n", - "sns.boxplot(df.intensity_covered, color=sns.color_palette(\"Paired\")[0])\n", + "sns.boxplot(df.intensity_covered, color=sns.color_palette('Paired')[0])\n", "plt.show()\n", "\n", "plt.figure(figsize=(12, 4))\n", "sns.boxplot(\n", - " df.intensity_covered_without_precursor, color=sns.color_palette(\"Paired\")[0]\n", + " df.intensity_covered_without_precursor, color=sns.color_palette('Paired')[0]\n", ")\n", "plt.show()\n", "\n", - "fig, axs = plt.subplots(2, 1, sharex=\"all\", figsize=(8, 6))\n", - "sns.boxplot(df.intensity_covered, color=sns.color_palette(\"Paired\")[0], ax=axs[0])\n", + "fig, axs = plt.subplots(2, 1, sharex='all', figsize=(8, 6))\n", + "sns.boxplot(df.intensity_covered, color=sns.color_palette('Paired')[0], ax=axs[0])\n", "sns.boxplot(\n", " df.intensity_covered_without_precursor,\n", - " color=sns.color_palette(\"Paired\")[2],\n", + " color=sns.color_palette('Paired')[2],\n", " ax=axs[1],\n", ")\n", "\n", "plt.show()\n", "\n", - "ax = sns.boxplot(x=df[\"Num Peaks\"], color=sns.color_palette(\"Paired\")[0])\n", + "ax = sns.boxplot(x=df['Num Peaks'], color=sns.color_palette('Paired')[0])\n", "ax.set_xlim([1, 50])\n", "plt.show()\n", "\n", @@ -7205,8 +7180,8 @@ }, "outputs": [], "source": [ - "df[\"fragmentation_tree\"] = df.apply(\n", - " lambda x: build_fragmentation_tree(x[\"mol\"], x[\"edges_idx\"], depth=2), axis=1\n", + "df['fragmentation_tree'] = df.apply(\n", + " lambda x: build_fragmentation_tree(x['mol'], x['edges_idx'], depth=2), axis=1\n", ")" ] }, @@ -7220,23 +7195,23 @@ }, "outputs": [], "source": [ - "df[\"matching_peaks_d1\"] = df.apply(\n", - " lambda x: get_matching_peaks_until_depth(x[\"fragmentation_tree\"], x[\"peaks\"], 1),\n", + "df['matching_peaks_d1'] = df.apply(\n", + " lambda x: get_matching_peaks_until_depth(x['fragmentation_tree'], x['peaks'], 1),\n", " axis=1,\n", ")\n", - "df[\"matching_peaks_d2\"] = df.apply(\n", - " lambda x: get_matching_peaks_until_depth(x[\"fragmentation_tree\"], x[\"peaks\"], 2),\n", + "df['matching_peaks_d2'] = df.apply(\n", + " lambda x: get_matching_peaks_until_depth(x['fragmentation_tree'], x['peaks'], 2),\n", " axis=1,\n", ")\n", "\n", - "df[\"num_peaks_matched_d1\"] = df[\"matching_peaks_d1\"].apply(lambda x: len(x[1]))\n", - "df[\"num_peaks_matched_d2\"] = df[\"matching_peaks_d2\"].apply(lambda x: len(x[1]))\n", + "df['num_peaks_matched_d1'] = df['matching_peaks_d1'].apply(lambda x: len(x[1]))\n", + "df['num_peaks_matched_d2'] = df['matching_peaks_d2'].apply(lambda x: len(x[1]))\n", "\n", - "df[\"intensity_covered_d1\"] = df.apply(\n", - " lambda x: sum(x[\"matching_peaks_d1\"][2]) / sum(x[\"peaks\"][\"intensity\"]), axis=1\n", + "df['intensity_covered_d1'] = df.apply(\n", + " lambda x: 
sum(x['matching_peaks_d1'][2]) / sum(x['peaks']['intensity']), axis=1\n", ")\n", - "df[\"intensity_covered_d2\"] = df.apply(\n", - " lambda x: sum(x[\"matching_peaks_d2\"][2]) / sum(x[\"peaks\"][\"intensity\"]), axis=1\n", + "df['intensity_covered_d2'] = df.apply(\n", + " lambda x: sum(x['matching_peaks_d2'][2]) / sum(x['peaks']['intensity']), axis=1\n", ")\n", "# df_nist[['num_peaks_matched_d1', 'intensity_covered_d1', 'num_peaks_matched_d2', 'intensity_covered_d2']]" ] @@ -7251,18 +7226,18 @@ }, "outputs": [], "source": [ - "fig, axs = plt.subplots(1, 2, sharey=\"all\", figsize=(8, 6))\n", - "sns.boxplot(ax=axs[0], y=df[\"num_peaks_matched_d1\"])\n", - "sns.boxplot(ax=axs[1], y=df[\"num_peaks_matched_d2\"])\n", + "fig, axs = plt.subplots(1, 2, sharey='all', figsize=(8, 6))\n", + "sns.boxplot(ax=axs[0], y=df['num_peaks_matched_d1'])\n", + "sns.boxplot(ax=axs[1], y=df['num_peaks_matched_d2'])\n", "plt.show()\n", "\n", "\n", - "fig, axs = plt.subplots(1, 2, sharey=\"all\", figsize=(8, 6))\n", - "sns.boxplot(ax=axs[0], y=df[\"intensity_covered_d1\"], color=\"pink\")\n", - "sns.boxplot(ax=axs[1], y=df[\"intensity_covered_d2\"], color=\"pink\")\n", + "fig, axs = plt.subplots(1, 2, sharey='all', figsize=(8, 6))\n", + "sns.boxplot(ax=axs[0], y=df['intensity_covered_d1'], color='pink')\n", + "sns.boxplot(ax=axs[1], y=df['intensity_covered_d2'], color='pink')\n", "plt.show()\n", "\n", - "print(sum(df[df[\"intensity_covered_d2\"] > 0.5][\"num_peaks_matched_d2\"] > 1))" + "print(sum(df[df['intensity_covered_d2'] > 0.5]['num_peaks_matched_d2'] > 1))" ] }, { @@ -7275,7 +7250,7 @@ }, "outputs": [], "source": [ - "sns.histplot(df, x=\"num_peaks_matched\", bins=range(0, 12))\n", + "sns.histplot(df, x='num_peaks_matched', bins=range(0, 12))\n", "plt.show()\n", "# print(sum(df_nist['num_peaks_matched'] >= 3))\n", "\n", @@ -7298,9 +7273,9 @@ "source": [ "# df_candidates['high_intense'] = df_candidates.apply(lambda x: sum([any([(do_peaks_match(frag_peak, float(x['peaks']['mz'][i])) and int(x['peaks']['intensity'][i]) > 100) for i in range(len(x['peaks']['mz']))]) for frag_peak in x['unique_fragment_mz']]), axis=1)\n", "\n", - "df_candidates = df_candidates[df_candidates[\"high_intense\"] >= 2]\n", + "df_candidates = df_candidates[df_candidates['high_intense'] >= 2]\n", "\n", - "sns.histplot(df_candidates, x=\"high_intense\", bins=range(0, 12))\n", + "sns.histplot(df_candidates, x='high_intense', bins=range(0, 12))\n", "plt.show()" ] }, @@ -7334,13 +7309,13 @@ " f1 = any(\n", " [\n", " do_peaks_match(d.edge_fragment_mz[i][0], float(peak))\n", - " for peak in d[\"peaks\"][\"mz\"]\n", + " for peak in d['peaks']['mz']\n", " ]\n", " )\n", " f2 = any(\n", " [\n", " do_peaks_match(d.edge_fragment_mz[i][1], float(peak))\n", - " for peak in d[\"peaks\"][\"mz\"]\n", + " for peak in d['peaks']['mz']\n", " ]\n", " )\n", "\n", @@ -7371,7 +7346,7 @@ " [d.mol, nm, f[0], f[1]],\n", " molsPerRow=4,\n", " useSVG=True,\n", - " legends=[\"intact\", \"broken\", \"fragment 1\", \"fragment 2\"],\n", + " legends=['intact', 'broken', 'fragment 1', 'fragment 2'],\n", ")" ] }, @@ -7391,7 +7366,7 @@ " [d.mol, nm, f[0], f[1]],\n", " molsPerRow=4,\n", " useSVG=True,\n", - " legends=[\"intact\", \"broken\", \"fragment 1\", \"fragment 2\"],\n", + " legends=['intact', 'broken', 'fragment 1', 'fragment 2'],\n", ")" ] }, @@ -7484,7 +7459,7 @@ " ],\n", " molsPerRow=4,\n", " useSVG=True,\n", - " legends=[\"intact\", \"fragment 1\", \"fragment 2\", \"fragment 3\"],\n", + " legends=['intact', 'fragment 1', 'fragment 2', 'fragment 3'],\n", ")" ] 
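
The node loop in the next cell scores every fragment by its protonated mass, ExactMolWt(fragment) + PROTON_MZ. A small worked check of that convention, assuming the PROTON_MZ constant used throughout this notebook (about 1.00728 Da):

    from rdkit import Chem
    import rdkit.Chem.Descriptors as Descriptors

    frag = Chem.MolFromSmiles('C')  # methane, exact mass ~16.0313 Da
    mz = Descriptors.ExactMolWt(frag) + PROTON_MZ
    print('%.4f' % mz)              # ~17.0386, the [M+H]+ m/z of the fragment
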
}, @@ -7517,7 +7492,7 @@ "for node in fragmentation_tree.all_nodes():\n", " f = node.data\n", " mz = Chem.Descriptors.ExactMolWt(f) + PROTON_MZ\n", - " if any([do_peaks_match(mz, float(peak)) for peak in d[\"peaks\"][\"mz\"]]):\n", + " if any([do_peaks_match(mz, float(peak)) for peak in d['peaks']['mz']]):\n", " if node.identifier == 0:\n", " precursor_mz.append(mz)\n", " elif fragmentation_tree.parent(node.identifier).identifier == 0:\n", @@ -7535,7 +7510,7 @@ " o1_f,\n", " molsPerRow=4,\n", " useSVG=True,\n", - " legends=[\"intact\", \"fragment 1\", \"fragment 2\", \"fragment 3\"],\n", + " legends=['intact', 'fragment 1', 'fragment 2', 'fragment 3'],\n", ")" ] }, @@ -7549,8 +7524,8 @@ }, "outputs": [], "source": [ - "print(d.peaks[\"mz\"])\n", - "print(d.peaks[\"intensity\"])\n", + "print(d.peaks['mz'])\n", + "print(d.peaks['intensity'])\n", "print(np.unique(parents))" ] }, diff --git a/lib_loader/ms_dial_loader.ipynb b/lib_loader/ms_dial_loader.ipynb index f261461..355b6a8 100644 --- a/lib_loader/ms_dial_loader.ipynb +++ b/lib_loader/ms_dial_loader.ipynb @@ -16,32 +16,30 @@ "source": [ "import sys\n", "\n", - "print(f\"Working with Python {sys.version}\")\n", + "print(f'Working with Python {sys.version}')\n", "\n", - "import pandas as pd\n", - "from rdkit import Chem\n", - "from rdkit.Chem import AllChem\n", - "from rdkit.Chem import Draw\n", - "import rdkit.Chem.Descriptors as Descriptors\n", - "from rdkit.Chem import PandasTools\n", - "import matplotlib.pyplot as plt\n", - "import seaborn as sns\n", "import os\n", "\n", + "import matplotlib.pyplot as plt\n", + "\n", "# import pymzml\n", "import numpy as np\n", - "from rdkit import RDLogger\n", + "import pandas as pd\n", + "import rdkit.Chem.Descriptors as Descriptors\n", + "import seaborn as sns\n", + "from rdkit import Chem, RDLogger\n", + "from rdkit.Chem import AllChem, Draw, PandasTools\n", "\n", - "RDLogger.DisableLog(\"rdApp.*\")\n", + "RDLogger.DisableLog('rdApp.*')\n", "\n", "\n", "# Load Modules\n", - "sys.path.append(\"..\")\n", + "sys.path.append('..')\n", "from os.path import expanduser\n", "\n", - "home = expanduser(\"~\")\n", - "import fiora.IO.mspReader as mspReader\n", + "home = expanduser('~')\n", "import fiora.IO.mgfReader as mgfReader\n", + "import fiora.IO.mspReader as mspReader\n", "import fiora.visualization.spectrum_visualizer as sv" ] }, @@ -53,12 +51,12 @@ "source": [ "p = pd.DataFrame(\n", " mspReader.read(\n", - " f\"{home}/data/metabolites/MS_DIAL/MSMS_Public_EXP_Pos_VS17.msp\", sep=\"\\t\"\n", + " f'{home}/data/metabolites/MS_DIAL/MSMS_Public_EXP_Pos_VS17.msp', sep='\\t'\n", " )\n", ")\n", "n = pd.DataFrame(\n", " mspReader.read(\n", - " f\"{home}/data/metabolites/MS_DIAL/MSMS_Public_EXP_NEG_VS17.msp\", sep=\"\\t\"\n", + " f'{home}/data/metabolites/MS_DIAL/MSMS_Public_EXP_NEG_VS17.msp', sep='\\t'\n", " )\n", ")\n", "df = pd.concat([p, n])" @@ -114,7 +112,7 @@ } ], "source": [ - "df[df[\"CE_float\"]][\"INSTRUMENTTYPE\"].value_counts()" + "df[df['CE_float']]['INSTRUMENTTYPE'].value_counts()" ] }, { @@ -123,13 +121,13 @@ "metadata": {}, "outputs": [], "source": [ - "df[\"origin\"] = \"Untracked\"\n", + "df['origin'] = 'Untracked'\n", "for i, d in df.iterrows():\n", " try:\n", - " df.at[i, \"origin\"] = d[\"COMMENT\"].split(\"origin=\")[1]\n", + " df.at[i, 'origin'] = d['COMMENT'].split('origin=')[1]\n", " except:\n", - " if \"MetaboBASE\" in d[\"COMMENT\"]:\n", - " df.at[i, \"origin\"] = \"MetaboBASE\"" + " if 'MetaboBASE' in d['COMMENT']:\n", + " df.at[i, 'origin'] = 'MetaboBASE'" ] }, { @@ -162,7 +160,7 @@ 
} ], "source": [ - "df[\"origin\"].value_counts()" + "df['origin'].value_counts()" ] }, { @@ -195,7 +193,7 @@ ], "source": [ "sns.histplot(\n", - " df[df[\"origin\"] == \"Vaniya/Fiehn Natural Products Library\"][\"RETENTIONTIME\"].astype(\n", + " df[df['origin'] == 'Vaniya/Fiehn Natural Products Library']['RETENTIONTIME'].astype(\n", " float\n", " )\n", ")" @@ -207,7 +205,7 @@ "metadata": {}, "outputs": [], "source": [ - "df[\"RETENTIONTIME\"] = df[\"RETENTIONTIME\"].astype(float)\n", + "df['RETENTIONTIME'] = df['RETENTIONTIME'].astype(float)\n", "df.reset_index(inplace=True)" ] }, @@ -250,16 +248,16 @@ ], "source": [ "potential_homogenous_RT_libs = [\n", - " \"RIKEN Plant Specialized Metabolome Annotation (PlaSMA) Authentic Standard Library\",\n", - " \"BMDMS-NP\",\n", + " 'RIKEN Plant Specialized Metabolome Annotation (PlaSMA) Authentic Standard Library',\n", + " 'BMDMS-NP',\n", "] # , 'Vaniya/Fiehn Natural Products Library']#, \"Global Natural Product Social Molecular Networking Library\"]\n", "\n", "sns.histplot(\n", - " data=df[df[\"origin\"].isin(potential_homogenous_RT_libs)],\n", - " x=\"RETENTIONTIME\",\n", - " hue=\"origin\",\n", + " data=df[df['origin'].isin(potential_homogenous_RT_libs)],\n", + " x='RETENTIONTIME',\n", + " hue='origin',\n", " common_norm=False,\n", - " stat=\"density\",\n", + " stat='density',\n", ")\n", "plt.xlim([0, 20])" ] @@ -295,9 +293,9 @@ ], "source": [ "sns.histplot(\n", - " df[\"SMILES\"],\n", + " df['SMILES'],\n", ")\n", - "df[\"origin\"].value_counts()" + "df['origin'].value_counts()" ] }, { @@ -306,7 +304,7 @@ "metadata": {}, "outputs": [], "source": [ - "BMDMS = df[df[\"origin\"] == \"BMDMS-NP\"]" + "BMDMS = df[df['origin'] == 'BMDMS-NP']" ] }, { @@ -326,8 +324,8 @@ } ], "source": [ - "print(len(BMDMS[\"SMILES\"].unique()))\n", - "print(BMDMS[\"PRECURSORTYPE\"].value_counts())" + "print(len(BMDMS['SMILES'].unique()))\n", + "print(BMDMS['PRECURSORTYPE'].value_counts())" ] }, { @@ -359,7 +357,7 @@ } ], "source": [ - "sns.histplot(BMDMS[\"RETENTIONTIME\"].astype(float))" + "sns.histplot(BMDMS['RETENTIONTIME'].astype(float))" ] }, { @@ -368,9 +366,9 @@ "metadata": {}, "outputs": [], "source": [ - "precursor_types = [\"[M+H]+\", \"[M-H]-\"]\n", + "precursor_types = ['[M+H]+', '[M-H]-']\n", "\n", - "df = df[df[\"PRECURSORTYPE\"].apply(lambda x: x in precursor_types)]" + "df = df[df['PRECURSORTYPE'].apply(lambda x: x in precursor_types)]" ] }, { @@ -381,10 +379,10 @@ "source": [ "from modules.MOL.collision_energy import align_CE\n", "\n", - "df[\"PRECURSORMZ\"] = df[\"PRECURSORMZ\"].astype(float)\n", - "df[\"CE\"] = df.apply(\n", + "df['PRECURSORMZ'] = df['PRECURSORMZ'].astype(float)\n", + "df['CE'] = df.apply(\n", " lambda x: align_CE(\n", - " x[\"COLLISIONENERGY\"], x[\"PRECURSORMZ\"], instrument=x[\"INSTRUMENTTYPE\"]\n", + " x['COLLISIONENERGY'], x['PRECURSORMZ'], instrument=x['INSTRUMENTTYPE']\n", " ),\n", " axis=1,\n", ")" @@ -407,9 +405,9 @@ } ], "source": [ - "df[\"CE_type\"] = df[\"CE\"].apply(type)\n", - "print(df[\"CE_type\"].value_counts())\n", - "print(sum(df[\"CE\"] == \"\")) # TODO assign random CE or 35 NCE ??)" + "df['CE_type'] = df['CE'].apply(type)\n", + "print(df['CE_type'].value_counts())\n", + "print(sum(df['CE'] == '')) # TODO assign random CE or 35 NCE ??)" ] }, { @@ -449,8 +447,8 @@ } ], "source": [ - "ce = df[df[\"CE_type\"] == str]\n", - "ce[\"CE\"].value_counts()[:20]" + "ce = df[df['CE_type'] == str]\n", + "ce['CE'].value_counts()[:20]" ] }, { @@ -474,7 +472,7 @@ } ], "source": [ - "ce[(ce[\"CE\"] == 
\"\")][\"origin\"].value_counts()" + "ce[(ce['CE'] == '')]['origin'].value_counts()" ] }, { @@ -506,7 +504,7 @@ } ], "source": [ - "sv.plot_spectrum(ce[(ce[\"CE\"] == \"\")].iloc[0])" + "sv.plot_spectrum(ce[(ce['CE'] == '')].iloc[0])" ] }, { @@ -515,9 +513,9 @@ "metadata": {}, "outputs": [], "source": [ - "df = df[df[\"CE_type\"] == float]\n", - "df[\"CE\"] = df[\"CE\"].astype(float)\n", - "df = df[df[\"CE\"] <= 1000.0]\n", + "df = df[df['CE_type'] == float]\n", + "df['CE'] = df['CE'].astype(float)\n", + "df = df[df['CE'] <= 1000.0]\n", "df.reset_index(inplace=True)" ] }, @@ -547,7 +545,7 @@ } ], "source": [ - "sns.displot(data=df, x=\"CE\", binwidth=5, kde=False)\n", + "sns.displot(data=df, x='CE', binwidth=5, kde=False)\n", "plt.xlim([0, 150])\n", "plt.show()\n", "print(df.shape)" @@ -582,12 +580,12 @@ "fig, axs = plt.subplots(1, 1, figsize=(12.8, 6.4), sharey=False)\n", "\n", "\n", - "top_instrumenttypes = df[\"INSTRUMENTTYPE\"].value_counts().head(6).index\n", + "top_instrumenttypes = df['INSTRUMENTTYPE'].value_counts().head(6).index\n", "sns.histplot(\n", - " data=df[df[\"INSTRUMENTTYPE\"].isin(top_instrumenttypes)],\n", - " x=\"CE\",\n", - " hue=\"INSTRUMENTTYPE\",\n", - " multiple=\"stack\",\n", + " data=df[df['INSTRUMENTTYPE'].isin(top_instrumenttypes)],\n", + " x='CE',\n", + " hue='INSTRUMENTTYPE',\n", + " multiple='stack',\n", " binwidth=5,\n", " kde=False,\n", ")\n", @@ -604,7 +602,7 @@ "source": [ "from modules.MOL.Metabolite import Metabolite\n", "\n", - "df[\"Metabolite\"] = df[\"SMILES\"].apply(Metabolite)" + "df['Metabolite'] = df['SMILES'].apply(Metabolite)" ] }, { @@ -635,18 +633,18 @@ "source": [ "from modules.MOL.constants import ADDUCT_WEIGHTS, PPM\n", "\n", - "df[\"Precursor_offset\"] = df[\"PRECURSORMZ\"] - df.apply(\n", - " lambda x: x[\"Metabolite\"].ExactMolWeight + ADDUCT_WEIGHTS[x[\"PRECURSORTYPE\"]],\n", + "df['Precursor_offset'] = df['PRECURSORMZ'] - df.apply(\n", + " lambda x: x['Metabolite'].ExactMolWeight + ADDUCT_WEIGHTS[x['PRECURSORTYPE']],\n", " axis=1,\n", ")\n", - "df[\"Precursor_abs_error\"] = abs(df[\"Precursor_offset\"])\n", - "df[\"Precursor_rel_error\"] = df[\"Precursor_abs_error\"] / df[\"PRECURSORMZ\"]\n", - "df[\"Precursor_ppm_error\"] = df[\"Precursor_abs_error\"] / (df[\"PRECURSORMZ\"] * PPM)\n", - "print((df[\"Precursor_ppm_error\"] > 500).sum())\n", - "d = df[df[\"Precursor_ppm_error\"] > 500].iloc[0]\n", - "print(d[\"Metabolite\"].ExactMolWeight)\n", - "print(d[\"Metabolite\"])\n", - "df[\"SMILES\"].apply(lambda x: \"+\" in x).sum()" + "df['Precursor_abs_error'] = abs(df['Precursor_offset'])\n", + "df['Precursor_rel_error'] = df['Precursor_abs_error'] / df['PRECURSORMZ']\n", + "df['Precursor_ppm_error'] = df['Precursor_abs_error'] / (df['PRECURSORMZ'] * PPM)\n", + "print((df['Precursor_ppm_error'] > 500).sum())\n", + "d = df[df['Precursor_ppm_error'] > 500].iloc[0]\n", + "print(d['Metabolite'].ExactMolWeight)\n", + "print(d['Metabolite'])\n", + "df['SMILES'].apply(lambda x: '+' in x).sum()" ] }, { @@ -656,20 +654,20 @@ "outputs": [], "source": [ "%%capture\n", - "from modules.MOL.Metabolite import Metabolite\n", "from modules.MOL.constants import PPM\n", + "from modules.MOL.Metabolite import Metabolite\n", "\n", "TOLERANCE = 200 * PPM\n", "\n", "\n", - "df[\"Metabolite\"] = df[\"SMILES\"].apply(Metabolite)\n", - "df = df[df[\"Metabolite\"].apply(lambda x: x.is_single_connected_structure())]\n", - "df[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", - "df[\"Metabolite\"].apply(lambda x: 
x.compute_graph_attributes())\n", - "df[\"Metabolite\"].apply(lambda x: x.fragment_MOL())\n", + "df['Metabolite'] = df['SMILES'].apply(Metabolite)\n", + "df = df[df['Metabolite'].apply(lambda x: x.is_single_connected_structure())]\n", + "df['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", + "df['Metabolite'].apply(lambda x: x.compute_graph_attributes())\n", + "df['Metabolite'].apply(lambda x: x.fragment_MOL())\n", "df.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=TOLERANCE\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=TOLERANCE\n", " ),\n", " axis=1,\n", ")" @@ -701,18 +699,18 @@ "outputs": [], "source": [ "# Define figure styles\n", - "color_palette = sns.color_palette(\"magma_r\", 8)\n", + "color_palette = sns.color_palette('magma_r', 8)\n", "sns.set_theme(\n", - " style=\"whitegrid\",\n", + " style='whitegrid',\n", " rc={\n", - " \"axes.edgecolor\": \"black\",\n", - " \"ytick.left\": True,\n", - " \"xtick.bottom\": True,\n", - " \"xtick.color\": \"black\",\n", - " \"axes.spines.bottom\": True,\n", - " \"axes.spines.right\": True,\n", - " \"axes.spines.top\": True,\n", - " \"axes.spines.left\": True,\n", + " 'axes.edgecolor': 'black',\n", + " 'ytick.left': True,\n", + " 'xtick.bottom': True,\n", + " 'xtick.color': 'black',\n", + " 'axes.spines.bottom': True,\n", + " 'axes.spines.right': True,\n", + " 'axes.spines.top': True,\n", + " 'axes.spines.left': True,\n", " },\n", ")" ] @@ -744,23 +742,23 @@ "from modules.MOL.mol_graph import draw_graph\n", "\n", "x = df.iloc[0]\n", - "x_mol = x[\"Metabolite\"].MOL\n", + "x_mol = x['Metabolite'].MOL\n", "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 1]}, sharey=False\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 1]}, sharey=False\n", ")\n", "\n", "img = Chem.Draw.MolToImage(x_mol, ax=axs[0])\n", "\n", "axs[0].grid(False)\n", "axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", ")\n", - "axs[0].set_title(x[\"NAME\"] + \" structure:\\n\" + Chem.MolToSmiles(x_mol))\n", + "axs[0].set_title(x['NAME'] + ' structure:\\n' + Chem.MolToSmiles(x_mol))\n", "axs[0].imshow(img)\n", - "axs[0].axis(\"off\")\n", + "axs[0].axis('off')\n", "\n", - "g_img = draw_graph(x[\"Metabolite\"].Graph, ax=axs[1])\n", - "print(x[\"peaks\"])" + "g_img = draw_graph(x['Metabolite'].Graph, ax=axs[1])\n", + "print(x['peaks'])" ] }, { @@ -828,15 +826,15 @@ "source": [ "from modules.MOL.constants import DEFAULT_MODES\n", "\n", - "df[\"peak_matches\"] = df[\"Metabolite\"].apply(lambda x: getattr(x, \"peak_matches\"))\n", - "df[\"num_peaks_matched\"] = df[\"peak_matches\"].apply(len)\n", + "df['peak_matches'] = df['Metabolite'].apply(lambda x: getattr(x, 'peak_matches'))\n", + "df['num_peaks_matched'] = df['peak_matches'].apply(len)\n", "\n", "\n", "def get_match_stats(matches, mode_count={m: 0 for m in DEFAULT_MODES}):\n", " num_unique, num_conflicts = 0, 0\n", " for mz, match_data in matches.items():\n", " # candidates = match_data[\"fragments\"]\n", - " ion_modes = match_data[\"ion_modes\"]\n", + " ion_modes = match_data['ion_modes']\n", " if len(ion_modes) == 1:\n", " num_unique += 1\n", " elif len(ion_modes) > 1:\n", @@ -846,16 +844,16 @@ " return num_unique, num_conflicts, mode_count\n", 
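
A usage sketch for get_match_stats with a hand-built matches dict (hypothetical m/z keys; ion modes drawn from DEFAULT_MODES, assuming it holds at least two entries). One Python subtlety in the definition above: the mode_count default is a single mutable dict created once at definition time, so calls that rely on the default — like the apply just below — accumulate counts across rows; passing a fresh dict keeps the counts per call:

    # Hypothetical input: one peak matched by a single ion mode (unique),
    # one matched by two modes (a conflict).
    modes = list(DEFAULT_MODES)
    matches = {
        101.06: {'ion_modes': [modes[0]]},
        135.04: {'ion_modes': modes[:2]},
    }
    n_unique, n_conflicts, counts = get_match_stats(
        matches, mode_count={m: 0 for m in DEFAULT_MODES}
    )
    print(n_unique, n_conflicts)  # 1 1
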
"\n", "\n", - "df[\"match_stats\"] = df[\"peak_matches\"].apply(lambda x: get_match_stats(x))\n", - "df[\"num_unique_peaks_matched\"] = df.apply(lambda x: x[\"match_stats\"][0], axis=1)\n", - "df[\"num_conflicts_in_peak_matching\"] = df.apply(lambda x: x[\"match_stats\"][1], axis=1)\n", - "df[\"match_mode_counts\"] = df.apply(lambda x: x[\"match_stats\"][2], axis=1)\n", - "u = df[\"num_unique_peaks_matched\"].sum()\n", - "s = df[\"num_conflicts_in_peak_matching\"].sum()\n", + "df['match_stats'] = df['peak_matches'].apply(lambda x: get_match_stats(x))\n", + "df['num_unique_peaks_matched'] = df.apply(lambda x: x['match_stats'][0], axis=1)\n", + "df['num_conflicts_in_peak_matching'] = df.apply(lambda x: x['match_stats'][1], axis=1)\n", + "df['match_mode_counts'] = df.apply(lambda x: x['match_stats'][2], axis=1)\n", + "u = df['num_unique_peaks_matched'].sum()\n", + "s = df['num_conflicts_in_peak_matching'].sum()\n", "print(\n", - " f\"Total number of uniquely matched peaks: {u} , conflicts found within {s} matches ({100 * s / (u + s):.02f} %))\"\n", + " f'Total number of uniquely matched peaks: {u} , conflicts found within {s} matches ({100 * s / (u + s):.02f} %))'\n", ")\n", - "print(f\"Total number of conflicting peak to fragment matches: {s}\")" + "print(f'Total number of conflicting peak to fragment matches: {s}')" ] }, { @@ -877,36 +875,36 @@ "source": [ "fig, axs = plt.subplots(1, 3, figsize=(12.8, 6.4), sharey=False)\n", "\n", - "fig.suptitle(f\"Identified peaks with fragment offset\")\n", + "fig.suptitle('Identified peaks with fragment offset')\n", "# plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n", "sns.histplot(\n", " ax=axs[0],\n", " data=df,\n", - " x=\"num_peaks_matched\",\n", + " x='num_peaks_matched',\n", " color=color_palette[0],\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " bins=range(0, 20, 1),\n", ")\n", "# axs[0].set_ylim(-0.5, 10)\n", - "axs[0].set_ylabel(\"peaks identified\")\n", + "axs[0].set_ylabel('peaks identified')\n", "\n", "\n", - "sns.boxplot(ax=axs[1], data=df, y=\"num_unique_peaks_matched\", color=color_palette[1])\n", + "sns.boxplot(ax=axs[1], data=df, y='num_unique_peaks_matched', color=color_palette[1])\n", "axs[1].set_ylim(-0.5, 15)\n", - "axs[1].set_xlabel(\"unique matches\")\n", - "axs[1].set_ylabel(\"\")\n", + "axs[1].set_xlabel('unique matches')\n", + "axs[1].set_ylabel('')\n", "\n", "\n", "sns.histplot(\n", " ax=axs[2],\n", " data=df,\n", - " x=\"num_conflicts_in_peak_matching\",\n", + " x='num_conflicts_in_peak_matching',\n", " color=color_palette[3],\n", " binwidth=1,\n", ")\n", "# axs[2].set_ylim(-0.5, 1000)\n", - "axs[2].set_xlabel(\"conflicts\")\n", - "axs[2].set_ylabel(\"\")\n", + "axs[2].set_xlabel('conflicts')\n", + "axs[2].set_ylabel('')\n", "\n", "plt.show()" ] @@ -938,21 +936,21 @@ " mode_counts[mode] += m[mode]\n", "\n", "\n", - "df[\"match_mode_counts\"].apply(update_mode_counts)\n", + "df['match_mode_counts'].apply(update_mode_counts)\n", "\n", "sns.barplot(\n", " ax=axs[0],\n", " x=list(mode_counts.keys()),\n", " y=[mode_counts[k] for k in mode_counts.keys()],\n", " palette=color_palette,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " linewidth=1.5,\n", ")\n", "axs[1].pie(\n", " [mode_counts[k] for k in mode_counts.keys()],\n", " labels=list(mode_counts.keys()),\n", " colors=color_palette,\n", - " wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5},\n", + " wedgeprops={'edgecolor': 'black', 'linewidth': 1.5},\n", ")\n", "\n", "plt.show()" @@ -986,19 +984,19 @@ "import pandas as 
pd\n", "from modules.MOL.Metabolite import Metabolite\n", "\n", - "casmi16_path = f\"{home}/data/metabolites/CASMI_2016/casmi16_challenges_combined.csv\"\n", + "casmi16_path = f'{home}/data/metabolites/CASMI_2016/casmi16_challenges_combined.csv'\n", "df_cas = pd.read_csv(casmi16_path, index_col=[0], low_memory=False)\n", - "df_cas[\"Metabolite\"] = df_cas[\"SMILES\"].apply(Metabolite)\n", - "df.dropna(subset=[\"SMILES\"], inplace=True)\n", - "df[\"in_casmi2016\"] = False\n", + "df_cas['Metabolite'] = df_cas['SMILES'].apply(Metabolite)\n", + "df.dropna(subset=['SMILES'], inplace=True)\n", + "df['in_casmi2016'] = False\n", "\n", "for i, d in df_cas.iterrows():\n", - " m = d[\"Metabolite\"]\n", + " m = d['Metabolite']\n", "\n", " for x, D in df.iterrows():\n", - " M = D[\"Metabolite\"]\n", + " M = D['Metabolite']\n", " if m == M:\n", - " df.at[x, \"in_casmi2016\"] = True\n", + " df.at[x, 'in_casmi2016'] = True\n", "del df_cas" ] }, @@ -1023,8 +1021,8 @@ "source": [ "for i in range(0, 6):\n", " print(\n", - " f\"Minimum {i} unique peaks identified (including precursors): \",\n", - " (df[\"num_unique_peaks_matched\"] >= i).sum(),\n", + " f'Minimum {i} unique peaks identified (including precursors): ',\n", + " (df['num_unique_peaks_matched'] >= i).sum(),\n", " )" ] }, @@ -1065,7 +1063,7 @@ } ], "source": [ - "df[(df[\"CE\"] < 1)][\"COLLISIONENERGY\"]" + "df[(df['CE'] < 1)]['COLLISIONENERGY']" ] }, { @@ -1095,7 +1093,7 @@ } ], "source": [ - "df[(df[\"CE\"] < 1)][\"COLLISIONENERGY\"]" + "df[(df['CE'] < 1)]['COLLISIONENERGY']" ] }, { @@ -1122,15 +1120,15 @@ ], "source": [ "save_df = False\n", - "lib = f\"{home}/data/metabolites/MS_DIAL/\"\n", - "name = \"ms_dial_filtered\"\n", - "date = \"XXX\" # \"mid_08_2023\" #\"mid_08_2023\" #\"07_2023\"\n", + "lib = f'{home}/data/metabolites/MS_DIAL/'\n", + "name = 'ms_dial_filtered'\n", + "date = 'XXX' # \"mid_08_2023\" #\"mid_08_2023\" #\"07_2023\"\n", "min_peaks = 5\n", "\n", "if save_df:\n", - " file = lib + name + \"_min\" + str(min_peaks) + \"_\" + date + \".csv\"\n", - " print(\"saving to \", file)\n", - " df[df[\"num_unique_peaks_matched\"] >= min_peaks].to_csv(file)\n", + " file = lib + name + '_min' + str(min_peaks) + '_' + date + '.csv'\n", + " print('saving to ', file)\n", + " df[df['num_unique_peaks_matched'] >= min_peaks].to_csv(file)\n", "\n", " # df.to_csv(lib + name + \"_all\" + \"_\" + date + \".csv\") #TODO HERE" ] @@ -1152,7 +1150,7 @@ } ], "source": [ - "sns.displot(df, x=\"PRECURSORMZ\", kde=True, binwidth=50)\n", + "sns.displot(df, x='PRECURSORMZ', kde=True, binwidth=50)\n", "plt.show()" ] }, @@ -1185,8 +1183,8 @@ } ], "source": [ - "df[\"RETENTIONTIME\"] = df[\"RETENTIONTIME\"].astype(float)\n", - "sns.displot(df, x=\"RETENTIONTIME\", kde=True, binwidth=0.5)\n", + "df['RETENTIONTIME'] = df['RETENTIONTIME'].astype(float)\n", + "sns.displot(df, x='RETENTIONTIME', kde=True, binwidth=0.5)\n", "plt.show()" ] }, @@ -1208,23 +1206,23 @@ ], "source": [ "sns.histplot(\n", - " df[~df[\"in_casmi2016\"]],\n", - " x=\"RETENTIONTIME\",\n", + " df[~df['in_casmi2016']],\n", + " x='RETENTIONTIME',\n", " kde=True,\n", " binwidth=1,\n", - " stat=\"density\",\n", - " multiple=\"stack\",\n", + " stat='density',\n", + " multiple='stack',\n", ")\n", "sns.histplot(\n", - " df[df[\"in_casmi2016\"]],\n", - " x=\"RETENTIONTIME\",\n", + " df[df['in_casmi2016']],\n", + " x='RETENTIONTIME',\n", " kde=True,\n", " binwidth=1,\n", - " stat=\"density\",\n", - " multiple=\"stack\",\n", - " color=\"orange\",\n", + " stat='density',\n", + " multiple='stack',\n", + " 
color='orange',\n", ")\n", - "plt.legend(labels=[\"Non-Casmi\", \"Casmi2016\"])\n", + "plt.legend(labels=['Non-Casmi', 'Casmi2016'])\n", "plt.xlim([0, 30])\n", "plt.show()" ] @@ -1258,25 +1256,25 @@ } ], "source": [ - "df[\"CCS\"] = df[\"CCS\"].astype(float)\n", + "df['CCS'] = df['CCS'].astype(float)\n", "sns.histplot(\n", - " df[~df[\"in_casmi2016\"]],\n", - " x=\"CCS\",\n", + " df[~df['in_casmi2016']],\n", + " x='CCS',\n", " kde=True,\n", " binwidth=10,\n", - " stat=\"density\",\n", - " multiple=\"stack\",\n", + " stat='density',\n", + " multiple='stack',\n", ")\n", "sns.histplot(\n", - " df[df[\"in_casmi2016\"]],\n", - " x=\"CCS\",\n", + " df[df['in_casmi2016']],\n", + " x='CCS',\n", " kde=True,\n", " binwidth=10,\n", - " stat=\"density\",\n", - " multiple=\"stack\",\n", - " color=\"orange\",\n", + " stat='density',\n", + " multiple='stack',\n", + " color='orange',\n", ")\n", - "plt.legend(labels=[\"Non-Casmi\", \"Casmi2016\"])\n", + "plt.legend(labels=['Non-Casmi', 'Casmi2016'])\n", "\n", "plt.show()" ] @@ -1303,11 +1301,11 @@ "sns.histplot(\n", " ax=ax,\n", " data=df,\n", - " x=\"RETENTIONTIME\",\n", - " hue=\"origin\",\n", - " multiple=\"stack\",\n", + " x='RETENTIONTIME',\n", + " hue='origin',\n", + " multiple='stack',\n", " binwidth=1,\n", - " stat=\"probability\",\n", + " stat='probability',\n", ")\n", "plt.xlim([0, 20])\n", "plt.show()" @@ -1333,9 +1331,9 @@ "fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4))\n", "\n", "sns.kdeplot(\n", - " ax=ax, data=df, x=\"RETENTIONTIME\", hue=\"origin\", multiple=\"fill\", common_norm=False\n", + " ax=ax, data=df, x='RETENTIONTIME', hue='origin', multiple='fill', common_norm=False\n", ")\n", - "ax.legend(bbox_to_anchor=(1.5, 0.8), labels=df[\"origin\"].unique())\n", + "ax.legend(bbox_to_anchor=(1.5, 0.8), labels=df['origin'].unique())\n", "plt.xlim([0, 30])\n", "plt.show()" ] @@ -1360,9 +1358,9 @@ "fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4))\n", "\n", "sns.kdeplot(\n", - " ax=ax, data=df, x=\"RETENTIONTIME\", hue=\"origin\", multiple=\"layer\", common_norm=False\n", + " ax=ax, data=df, x='RETENTIONTIME', hue='origin', multiple='layer', common_norm=False\n", ")\n", - "plt.legend(labels=df[\"origin\"].unique())\n", + "plt.legend(labels=df['origin'].unique())\n", "plt.xlim([0, 30])\n", "plt.show()" ] @@ -1395,7 +1393,7 @@ } ], "source": [ - "df[\"origin\"].value_counts()" + "df['origin'].value_counts()" ] }, { @@ -1413,8 +1411,8 @@ } ], "source": [ - "print(len(df[\"SMILES\"].unique()))\n", - "print(len(df[df[\"origin\"] == \"BMDMS-NP\"][\"SMILES\"].unique()))" + "print(len(df['SMILES'].unique()))\n", + "print(len(df[df['origin'] == 'BMDMS-NP']['SMILES'].unique()))" ] }, { @@ -1444,7 +1442,7 @@ } ], "source": [ - "sns.histplot(df[df[\"origin\"] == \"BMDMS-NP\"][\"SMILES\"].value_counts())" + "sns.histplot(df[df['origin'] == 'BMDMS-NP']['SMILES'].value_counts())" ] }, { @@ -1474,7 +1472,7 @@ } ], "source": [ - "sns.histplot(df[df[\"origin\"] != \"BMDMS-NP\"][\"SMILES\"].value_counts())" + "sns.histplot(df[df['origin'] != 'BMDMS-NP']['SMILES'].value_counts())" ] }, { @@ -1505,7 +1503,7 @@ ], "source": [ "# sns.catplot(data=df, x=\"origin\")\n", - "sns.histplot(data=df, x=\"origin\")" + "sns.histplot(data=df, x='origin')" ] }, { @@ -1535,12 +1533,12 @@ "fig, axs = plt.subplots(1, 1, figsize=(12.8, 6.4), sharey=False)\n", "\n", "\n", - "top_instrumenttypes = df[\"INSTRUMENTTYPE\"].value_counts().head(6).index\n", + "top_instrumenttypes = df['INSTRUMENTTYPE'].value_counts().head(6).index\n", "sns.histplot(\n", - " 
data=df[df[\"INSTRUMENTTYPE\"].isin(top_instrumenttypes)],\n", - " x=\"CE\",\n", - " hue=\"INSTRUMENTTYPE\",\n", - " multiple=\"stack\",\n", + " data=df[df['INSTRUMENTTYPE'].isin(top_instrumenttypes)],\n", + " x='CE',\n", + " hue='INSTRUMENTTYPE',\n", + " multiple='stack',\n", " binwidth=5,\n", " kde=False,\n", ")\n", diff --git a/lib_loader/msnlib_loader.ipynb b/lib_loader/msnlib_loader.ipynb index 8149513..af8b869 100644 --- a/lib_loader/msnlib_loader.ipynb +++ b/lib_loader/msnlib_loader.ipynb @@ -16,31 +16,29 @@ "source": [ "import sys\n", "\n", - "print(f\"Working with Python {sys.version}\")\n", + "print(f'Working with Python {sys.version}')\n", "\n", + "import ast\n", + "import os\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "import pandas as pd\n", - "from rdkit import Chem\n", - "from rdkit.Chem import AllChem\n", - "from rdkit.Chem import Draw\n", "import rdkit.Chem.Descriptors as Descriptors\n", - "from rdkit.Chem import PandasTools\n", - "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", - "import os\n", - "import ast\n", - "import numpy as np\n", - "from rdkit import RDLogger\n", + "from rdkit import Chem, RDLogger\n", + "from rdkit.Chem import AllChem, Draw, PandasTools\n", "\n", - "RDLogger.DisableLog(\"rdApp.*\")\n", + "RDLogger.DisableLog('rdApp.*')\n", "\n", "\n", "# Load Modules\n", - "sys.path.append(\"..\")\n", + "sys.path.append('..')\n", "from os.path import expanduser\n", "\n", - "home = expanduser(\"~\")\n", - "import fiora.IO.mspReader as mspReader\n", + "home = expanduser('~')\n", "import fiora.IO.mgfReader as mgfReader\n", + "import fiora.IO.mspReader as mspReader\n", "import fiora.visualization.spectrum_visualizer as sv" ] }, @@ -75,18 +73,18 @@ } ], "source": [ - "version = \"v7\"\n", - "path: str = f\"{home}/data/metabolites/MSnLib/{version}/\"\n", + "version = 'v7'\n", + "path: str = f'{home}/data/metabolites/MSnLib/{version}/'\n", "\n", "dfs = []\n", "for filename in os.listdir(path):\n", - " if filename.endswith(\"ms2.mgf\"):\n", + " if filename.endswith('ms2.mgf'):\n", " filepath = path + filename\n", - " print(f\"Reading {filename}\")\n", + " print(f'Reading {filename}')\n", " df = pd.DataFrame(mgfReader.read(path + filename))\n", - " df[\"file\"] = filename\n", - " df[\"lib\"] = \"MSnLib\"\n", - " df[\"origin\"] = filename.split(\"_\")[1]\n", + " df['file'] = filename\n", + " df['lib'] = 'MSnLib'\n", + " df['origin'] = filename.split('_')[1]\n", " dfs.append(df)\n", "\n", "df = pd.concat(dfs)\n", @@ -107,7 +105,7 @@ } ], "source": [ - "print(f\"Total number of data entries: {len(df)}\")" + "print(f'Total number of data entries: {len(df)}')" ] }, { @@ -116,12 +114,12 @@ "metadata": {}, "outputs": [], "source": [ - "delim = \", \" if version == \"v5\" else \",\"\n", - "df[\"CE_steps\"] = df[\"COLLISION_ENERGY\"].apply(\n", - " lambda x: [float(v) for v in x.strip(\"[]\").split(delim)] if \"[\" in x else [float(x)]\n", + "delim = ', ' if version == 'v5' else ','\n", + "df['CE_steps'] = df['COLLISION_ENERGY'].apply(\n", + " lambda x: [float(v) for v in x.strip('[]').split(delim)] if '[' in x else [float(x)]\n", ")\n", - "df[\"Num_steps\"] = df[\"CE_steps\"].apply(len)\n", - "df[\"CE\"] = df[\"CE_steps\"].apply(lambda x: sum(x) / len(x))" + "df['Num_steps'] = df['CE_steps'].apply(len)\n", + "df['CE'] = df['CE_steps'].apply(lambda x: sum(x) / len(x))" ] }, { @@ -143,7 +141,7 @@ } ], "source": [ - "df[df[\"Num_steps\"] == 6][\"CE_steps\"].head(2)" + "df[df['Num_steps'] == 6]['CE_steps'].head(2)" ] }, { @@ -167,29 
+165,29 @@ "\n", "set_light_theme()\n", "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(15, 5), gridspec_kw={\"width_ratios\": [1, 1]}, sharey=True\n", + " 1, 2, figsize=(15, 5), gridspec_kw={'width_ratios': [1, 1]}, sharey=True\n", ")\n", "\n", "\n", "sns.histplot(\n", " ax=axs[0],\n", - " data=df[df[\"SPECTYPE\"].isin([\"SINGLE_BEST_SCAN\", \"SAME_ENERGY\"])],\n", - " x=\"CE\",\n", - " hue=\"origin\",\n", - " multiple=\"stack\",\n", + " data=df[df['SPECTYPE'].isin(['SINGLE_BEST_SCAN', 'SAME_ENERGY'])],\n", + " x='CE',\n", + " hue='origin',\n", + " multiple='stack',\n", " binwidth=2,\n", ")\n", "sns.histplot(\n", " ax=axs[1],\n", - " data=df[df[\"SPECTYPE\"].isin([\"ALL_ENERGIES\", \"ALL_MSN_TO_PSEUDO_MS2\"])],\n", - " x=\"CE\",\n", - " hue=\"origin\",\n", - " multiple=\"stack\",\n", + " data=df[df['SPECTYPE'].isin(['ALL_ENERGIES', 'ALL_MSN_TO_PSEUDO_MS2'])],\n", + " x='CE',\n", + " hue='origin',\n", + " multiple='stack',\n", " binwidth=2,\n", ")\n", - "axs[0].set_title(\"Single Energy\")\n", - "axs[0].legend(\"\")\n", - "axs[1].set_title(\"Multiple Energies (Average)\")\n", + "axs[0].set_title('Single Energy')\n", + "axs[0].legend('')\n", + "axs[1].set_title('Multiple Energies (Average)')\n", "plt.show()" ] }, @@ -211,32 +209,32 @@ ], "source": [ "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(10, 5), gridspec_kw={\"width_ratios\": [1.5, 1]}, sharey=True\n", + " 1, 2, figsize=(10, 5), gridspec_kw={'width_ratios': [1.5, 1]}, sharey=True\n", ")\n", "# plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", "\n", "sns.histplot(\n", " ax=axs[0],\n", - " data=df[df[\"IONMODE\"] == \"Positive\"],\n", - " x=\"ADDUCT\",\n", + " data=df[df['IONMODE'] == 'Positive'],\n", + " x='ADDUCT',\n", " palette=magma(7),\n", - " hue=\"ADDUCT\",\n", - " edgecolor=\"black\",\n", - " stat=\"density\",\n", + " hue='ADDUCT',\n", + " edgecolor='black',\n", + " stat='density',\n", ")\n", - "axs[0].tick_params(axis=\"x\", rotation=60)\n", - "axs[0].set_xlabel(\"\")\n", + "axs[0].tick_params(axis='x', rotation=60)\n", + "axs[0].set_xlabel('')\n", "sns.histplot(\n", " ax=axs[1],\n", - " data=df[df[\"IONMODE\"] == \"Negative\"],\n", - " x=\"ADDUCT\",\n", - " palette=sns.color_palette(\"mako_r\", 4),\n", - " hue=\"ADDUCT\",\n", - " edgecolor=\"black\",\n", - " stat=\"density\",\n", + " data=df[df['IONMODE'] == 'Negative'],\n", + " x='ADDUCT',\n", + " palette=sns.color_palette('mako_r', 4),\n", + " hue='ADDUCT',\n", + " edgecolor='black',\n", + " stat='density',\n", ")\n", - "axs[1].set_xlabel(\"\")\n", - "axs[1].tick_params(axis=\"x\", rotation=60)\n", + "axs[1].set_xlabel('')\n", + "axs[1].tick_params(axis='x', rotation=60)\n", "plt.show()" ] }, @@ -260,12 +258,12 @@ } ], "source": [ - "print(df[\"SPECTYPE\"].value_counts(dropna=False))\n", + "print(df['SPECTYPE'].value_counts(dropna=False))\n", "\n", "filter_spectype = True\n", - "keep_spectypes = [\"SINGLE_BEST_SCAN\", \"SAME_ENERGY\", \"SINGLE_SCAN\"]\n", + "keep_spectypes = ['SINGLE_BEST_SCAN', 'SAME_ENERGY', 'SINGLE_SCAN']\n", "if filter_spectype:\n", - " df = df[df[\"SPECTYPE\"].isin(keep_spectypes)]\n", + " df = df[df['SPECTYPE'].isin(keep_spectypes)]\n", "\n", "# Note that this early filter step speeds up subsequent operations,\n", "# but one may consider to include stepped/merged spectra and specifically model CE steps." 
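
A worked pass over the COLLISION_ENERGY parsing a few cells above (hypothetical values): stepped spectra arrive as a bracketed list, single scans as a bare number, and the CE column is the mean of the steps:

    # Mirrors the CE_steps/CE logic above for version v7, where delim = ','.
    delim = ','
    for raw in ['[10.0,20.0,40.0]', '30']:
        steps = [float(v) for v in raw.strip('[]').split(delim)] if '[' in raw else [float(raw)]
        print(raw, '->', steps, 'mean CE:', sum(steps) / len(steps))
    # [10.0,20.0,40.0] -> [10.0, 20.0, 40.0] mean CE: 23.33...
    # 30 -> [30.0] mean CE: 30.0
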
@@ -277,9 +275,9 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "from fiora.MOL.Metabolite import Metabolite\n",
-    "from fiora.MOL.constants import PPM\n",
     "import fiora.MOL.constants as constants\n",
+    "from fiora.MOL.constants import PPM\n",
+    "from fiora.MOL.Metabolite import Metabolite\n",
     "\n",
     "PPM_NUM: int = 10\n",
     "TOLERANCE = PPM_NUM * PPM"
   ]
  },
@@ -291,11 +289,11 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "df[\"PPM_num\"] = PPM_NUM\n",
-    "df[\"ppm_peak_tolerance\"] = TOLERANCE\n",
-    "df[\"Metabolite\"] = df[\"SMILES\"].apply(Metabolite)\n",
-    "df[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n",
-    "_ = df[\"Metabolite\"].apply(\n",
+    "df['PPM_num'] = PPM_NUM\n",
+    "df['ppm_peak_tolerance'] = TOLERANCE\n",
+    "df['Metabolite'] = df['SMILES'].apply(Metabolite)\n",
+    "df['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n",
+    "_ = df['Metabolite'].apply(\n",
     "    lambda x: x.compute_graph_attributes(memory_safe=False)\n",
     ")  # Set memory_safe=False if necessary"
   ]
  },
@@ -309,7 +307,7 @@
     "from fiora.MOL.MetaboliteIndex import MetaboliteIndex\n",
     "\n",
     "mindex: MetaboliteIndex = MetaboliteIndex()\n",
-    "mindex.index_metabolites(df[\"Metabolite\"])"
+    "mindex.index_metabolites(df['Metabolite'])"
   ]
  },
 {
@@ -318,19 +316,19 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "h_plus = Chem.MolFromSmiles(\"[H+]\")  # h proton\n",
+    "h_plus = Chem.MolFromSmiles('[H+]')  # h proton\n",
     "\n",
     "constants.ADDUCT_WEIGHTS.update(\n",
     "    {\n",
-    "        \"[M+2H]-\": Descriptors.ExactMolWt(h_plus)\n",
+    "        '[M+2H]-': Descriptors.ExactMolWt(h_plus)\n",
     "        + 1\n",
     "        * Descriptors.ExactMolWt(\n",
-    "            Chem.MolFromSmiles(\"[H]\")\n",
+    "            Chem.MolFromSmiles('[H]')\n",
     "        ),  # 1 proton + 1 neutral hydrogen\n",
-    "        \"[M+3H]-\": Descriptors.ExactMolWt(h_plus)\n",
+    "        '[M+3H]-': Descriptors.ExactMolWt(h_plus)\n",
     "        + 2\n",
     "        * Descriptors.ExactMolWt(\n",
-    "            Chem.MolFromSmiles(\"[H]\")\n",
+    "            Chem.MolFromSmiles('[H]')\n",
     "        ),  # 1 proton + 2 neutral hydrogens\n",
     "    }\n",
     ")\n",
@@ -352,9 +350,9 @@
   ],
   "source": [
     "list_of_mismatched_ids = mindex.add_fragmentation_trees_to_metabolite_list(\n",
-    "    df[\"Metabolite\"], graph_mismatch_policy=\"recompute\"\n",
+    "    df['Metabolite'], graph_mismatch_policy='recompute'\n",
     ")\n",
-    "print(f\"Total number of recomputed trees: {len(list_of_mismatched_ids)}\")"
+    "print(f'Total number of recomputed trees: {len(list_of_mismatched_ids)}')"
   ]
  },
 {
@@ -372,20 +370,20 @@
    }
   ],
   "source": [
-    "df[\"group_id\"] = df[\"Metabolite\"].apply(lambda x: x.get_id())\n",
-    "df[\"num_per_group\"] = df[\"group_id\"].map(df[\"group_id\"].value_counts())\n",
+    "df['group_id'] = df['Metabolite'].apply(lambda x: x.get_id())\n",
+    "df['num_per_group'] = df['group_id'].map(df['group_id'].value_counts())\n",
     "for i, data in df.iterrows():\n",
-    "    data[\"Metabolite\"].set_loss_weight(1.0 / data[\"num_per_group\"])\n",
-    "df[\"loss_weight\"] = df[\"Metabolite\"].apply(lambda x: x.loss_weight)\n",
-    "print(f\"Number of metabolites in index: {mindex.get_number_of_metabolites()}\")\n",
+    "    data['Metabolite'].set_loss_weight(1.0 / data['num_per_group'])\n",
+    "df['loss_weight'] = df['Metabolite'].apply(lambda x: x.loss_weight)\n",
+    "print(f'Number of metabolites in index: {mindex.get_number_of_metabolites()}')\n",
     "\n",
     "\n",
     "def print_df_stats(df):\n",
     "    num_spectra = df.shape[0]\n",
-    "    num_ids = len(df[\"group_id\"].unique())\n",
+    "    num_ids = len(df['group_id'].unique())\n",
     "\n",
     "    print(\n",
-    "        f\"Dataframe stats: {num_spectra} spectra covering {num_ids} unique structures\"\n",
+    "        f'Dataframe stats: {num_spectra} spectra covering {num_ids} unique structures'\n",
     "    )\n",
     "\n",
     "\n",
@@ -399,9 +397,9 @@
   "outputs": [],
   "source": [
     "_ = df.apply(\n",
-    "    lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n",
-    "        x[\"peaks\"][\"mz\"],\n",
-    "        x[\"peaks\"][\"intensity\"],\n",
+    "    lambda x: x['Metabolite'].match_fragments_to_peaks(\n",
+    "        x['peaks']['mz'],\n",
+    "        x['peaks']['intensity'],\n",
     "        tolerance=TOLERANCE,\n",
     "        match_stats_only=True,\n",
     "    ),\n",
@@ -426,18 +424,18 @@
    }
   ],
   "source": [
-    "df[\"coverage\"] = df[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n",
+    "df['coverage'] = df['Metabolite'].apply(lambda x: x.match_stats['coverage'])\n",
     "sns.histplot(\n",
-    "    data=df[df[\"SPECTYPE\"].isin([\"SINGLE_BEST_SCAN\", \"SAME_ENERGY\"])],\n",
-    "    x=\"coverage\",\n",
-    "    hue=\"CE\",\n",
-    "    palette=\"magma_r\",\n",
-    "    edgecolor=\"black\",\n",
-    "    multiple=\"stack\",\n",
-    "    stat=\"density\",\n",
-    "    hue_norm=(df[\"CE\"].min(), df[\"CE\"].max()),\n",
+    "    data=df[df['SPECTYPE'].isin(['SINGLE_BEST_SCAN', 'SAME_ENERGY'])],\n",
+    "    x='coverage',\n",
+    "    hue='CE',\n",
+    "    palette='magma_r',\n",
+    "    edgecolor='black',\n",
+    "    multiple='stack',\n",
+    "    stat='density',\n",
+    "    hue_norm=(df['CE'].min(), df['CE'].max()),\n",
     ")\n",
-    "plt.xlabel(\"Peak intensity covered by single fragmentation events\")\n",
+    "plt.xlabel('Peak intensity covered by single fragmentation events')\n",
     "\n",
     "plt.show()"
   ]
  },
@@ -459,18 +457,18 @@
    }
   ],
   "source": [
-    "df[\"Num peaks\"] = df[\"Num peaks\"].astype(int)\n",
+    "df['Num peaks'] = df['Num peaks'].astype(int)\n",
     "sns.histplot(\n",
     "    df,\n",
-    "    x=\"Num peaks\",\n",
+    "    x='Num peaks',\n",
     "    color=magma(6)[2],\n",
-    "    multiple=\"stack\",\n",
-    "    stat=\"density\",\n",
+    "    multiple='stack',\n",
+    "    stat='density',\n",
     "    binwidth=5,\n",
-    "    edgecolor=\"white\",\n",
+    "    edgecolor='white',\n",
     "    linewidth=0.5,\n",
     ")\n",
-    "plt.xlabel(\"Num of peaks\")\n",
+    "plt.xlabel('Num of peaks')\n",
     "plt.xlim(0, 250)\n",
     "plt.show()"
   ]
  },
@@ -495,11 +493,11 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "cast_float = [\"PEPMASS\", \"RTINSECONDS\"]\n",
+    "cast_float = ['PEPMASS', 'RTINSECONDS']\n",
     "df[cast_float] = df[cast_float].astype(float)\n",
-    "df[\"ionization\"] = \"ESI\"\n",
-    "df[\"instrument\"] = \"HCD\"\n",
-    "df[\"Precursor_type\"] = df[\"ADDUCT\"]"
+    "df['ionization'] = 'ESI'\n",
+    "df['instrument'] = 'HCD'\n",
+    "df['Precursor_type'] = df['ADDUCT']"
   ]
  },
 {
@@ -522,64 +520,64 @@
     "from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder\n",
     "\n",
     "metadata_key_map = {\n",
-    "    \"name\": \"NAME\",\n",
-    "    \"collision_energy\": \"CE\",\n",
-    "    \"instrument\": \"instrument\",\n",
-    "    \"ionization\": \"ionization\",\n",
-    "    \"precursor_mz\": \"PEPMASS\",\n",
-    "    \"precursor_mode\": \"Precursor_type\",\n",
-    "    \"retention_time\": \"RTINSECONDS\",\n",
-    "    \"ce_steps\": \"CE_steps\",\n",
+    "    'name': 'NAME',\n",
+    "    'collision_energy': 'CE',\n",
+    "    'instrument': 'instrument',\n",
+    "    'ionization': 'ionization',\n",
+    "    'precursor_mz': 'PEPMASS',\n",
+    "    'precursor_mode': 'Precursor_type',\n",
+    "    'retention_time': 'RTINSECONDS',\n",
+    "    'ce_steps': 'CE_steps',\n",
     "}\n",
     "\n",
     "filter_spectra = True\n",
     "CE_upper_limit = 100.0\n",
     "weight_upper_limit = 1000.0\n",
-    "allowed_precursor_modes = [\"[M+H]+\", \"[M-H]-\", \"[M]+\", \"[M]-\"]\n",
+    "allowed_precursor_modes = ['[M+H]+', '[M-H]-', '[M]+', '[M]-']\n",
     "\n",
     "\n",
-    "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n",
-    "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n",
+    "node_encoder = AtomFeatureEncoder(feature_list=['symbol', 'num_hydrogen', 'ring_type'])\n",
+    "bond_encoder = BondFeatureEncoder(feature_list=['bond_type', 'ring_type'])\n",
     "setup_encoder = CovariateFeatureEncoder(\n",
     "    feature_list=[\n",
-    "        \"collision_energy\",\n",
-    "        \"molecular_weight\",\n",
-    "        \"precursor_mode\",\n",
-    "        \"instrument\",\n",
+    "        'collision_energy',\n",
+    "        'molecular_weight',\n",
+    "        'precursor_mode',\n",
+    "        'instrument',\n",
     "    ]\n",
     ")\n",
     "rt_encoder = CovariateFeatureEncoder(\n",
-    "    feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"]\n",
+    "    feature_list=['molecular_weight', 'precursor_mode', 'instrument']\n",
     ")\n",
     "\n",
     "if filter_spectra:\n",
-    "    setup_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n",
-    "    setup_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n",
-    "    rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n",
+    "    setup_encoder.normalize_features['collision_energy']['max'] = CE_upper_limit\n",
+    "    setup_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n",
+    "    rt_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n",
     "\n",
-    "df[\"summary\"] = df.apply(\n",
+    "df['summary'] = df.apply(\n",
     "    lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n",
     ")\n",
     "df.apply(\n",
-    "    lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n",
+    "    lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n",
     "    axis=1,\n",
     ")\n",
     "\n",
     "if filter_spectra:\n",
-    "    df = df[df[\"ADDUCT\"].isin(allowed_precursor_modes)]\n",
+    "    df = df[df['ADDUCT'].isin(allowed_precursor_modes)]\n",
     "    num_ori = df.shape[0]\n",
-    "    correct_energy = df[\"Metabolite\"].apply(\n",
+    "    correct_energy = df['Metabolite'].apply(\n",
     "        lambda x: (\n",
-    "            x.metadata[\"collision_energy\"] <= CE_upper_limit\n",
-    "            and x.metadata[\"collision_energy\"] > 1\n",
+    "            x.metadata['collision_energy'] <= CE_upper_limit\n",
+    "            and x.metadata['collision_energy'] > 1\n",
     "        )\n",
     "    )\n",
     "    df = df[correct_energy]\n",
-    "    correct_weight = df[\"Metabolite\"].apply(\n",
-    "        lambda x: x.metadata[\"molecular_weight\"] <= weight_upper_limit\n",
+    "    correct_weight = df['Metabolite'].apply(\n",
+    "        lambda x: x.metadata['molecular_weight'] <= weight_upper_limit\n",
     "    )\n",
     "    df = df[correct_weight]\n",
-    "    print(f\"Filtering spectra ({num_ori}) down to {df.shape[0]}\")\n",
+    "    print(f'Filtering spectra ({num_ori}) down to {df.shape[0]}')\n",
     "    # df = df[df[\"SPECTYPE\"] != \"ALL_ENERGIES\"]\n",
     "    # print(df[\"Precursor_type\"].value_counts())"
   ]
  },
@@ -605,7 +603,7 @@
    }
   ],
   "source": [
-    "df[\"SPECTYPE\"].value_counts(dropna=False)"
+    "df['SPECTYPE'].value_counts(dropna=False)"
   ]
  },
 {
@@ -627,20 +625,20 @@
    }
   ],
   "source": [
-    "s_single_best = set(df[df[\"SPECTYPE\"] == \"SINGLE_BEST_SCAN\"].group_id.unique())\n",
-    "s_single = set(df[df[\"SPECTYPE\"] == \"SINGLE_SCAN\"].group_id.unique())\n",
-    "s_same = set(df[df[\"SPECTYPE\"] == \"SAME_ENERGY\"].group_id.unique())\n",
-    "s_all = set(df[df[\"SPECTYPE\"] == \"ALL_ENERGIES\"].group_id.unique())\n",
-    "s_pseudo = set(df[df[\"SPECTYPE\"] == \"ALL_MSN_TO_PSEUDO_MS2\"].group_id.unique())\n",
-    "print(\"Num values in s_single_best:\", len(s_single_best))\n",
-    "print(\"Num values in s_single:\", len(s_single))\n",
-    "print(\"Num values in s_same:\", len(s_same))\n",
+    "s_single_best = set(df[df['SPECTYPE'] == 'SINGLE_BEST_SCAN'].group_id.unique())\n",
+    "s_single = set(df[df['SPECTYPE'] == 'SINGLE_SCAN'].group_id.unique())\n",
+    "s_same = set(df[df['SPECTYPE'] == 'SAME_ENERGY'].group_id.unique())\n",
+    "s_all = set(df[df['SPECTYPE'] == 'ALL_ENERGIES'].group_id.unique())\n",
+    "s_pseudo = set(df[df['SPECTYPE'] == 'ALL_MSN_TO_PSEUDO_MS2'].group_id.unique())\n",
+    "print('Num values in s_single_best:', len(s_single_best))\n",
+    "print('Num values in s_single:', len(s_single))\n",
+    "print('Num values in s_same:', len(s_same))\n",
     "print(\n",
-    "    \"Num values in combined single and same:\",\n",
+    "    'Num values in combined single and same:',\n",
     "    len(s_single_best.union(s_same).union(s_single)),\n",
     ")\n",
-    "print(\"Num values in s_all:\", len(s_all))\n",
-    "print(\"Num values in s_pseudo:\", len(s_pseudo))"
+    "print('Num values in s_all:', len(s_all))\n",
+    "print('Num values in s_pseudo:', len(s_pseudo))"
   ]
  },
 {
@@ -672,7 +670,7 @@
    }
   ],
   "source": [
-    "df[\"Metabolite\"].iloc[0].draw()"
+    "df['Metabolite'].iloc[0].draw()"
   ]
  },
 {
@@ -695,15 +693,15 @@
     "import networkx as nx\n",
     "from matplotlib import pyplot as plt\n",
     "\n",
-    "df[\"Metabolite\"].iloc[0].Graph\n",
+    "df['Metabolite'].iloc[0].Graph\n",
     "\n",
-    "graph = df[\"Metabolite\"].iloc[0].Graph\n",
+    "graph = df['Metabolite'].iloc[0].Graph\n",
     "plt.figure(figsize=(10, 10))\n",
     "nx.draw(\n",
     "    graph,\n",
     "    with_labels=True,\n",
-    "    node_color=\"lightblue\",\n",
-    "    edge_color=\"gray\",\n",
+    "    node_color='lightblue',\n",
+    "    edge_color='gray',\n",
     "    node_size=500,\n",
     "    font_size=10,\n",
     ")\n",
@@ -746,7 +744,7 @@
    }
   ],
   "source": [
-    "df[\"Metabolite\"].iloc[0].fragmentation_tree.get_all_fragments()"
+    "df['Metabolite'].iloc[0].fragmentation_tree.get_all_fragments()"
   ]
  },
 {
@@ -774,8 +772,8 @@
   ],
   "source": [
     "f = 3\n",
-    "print(df[\"Metabolite\"].iloc[0].fragmentation_tree.get_all_fragments()[f])\n",
-    "df[\"Metabolite\"].iloc[0].fragmentation_tree.get_all_fragments()[f].subgraphs"
+    "print(df['Metabolite'].iloc[0].fragmentation_tree.get_all_fragments()[f])\n",
+    "df['Metabolite'].iloc[0].fragmentation_tree.get_all_fragments()[f].subgraphs"
   ]
  },
 {
@@ -822,7 +820,7 @@
    }
   ],
   "source": [
-    "df[\"Metabolite\"].iloc[0].fragmentation_tree"
+    "df['Metabolite'].iloc[0].fragmentation_tree"
   ]
  },
 {
@@ -867,8 +865,8 @@
   "outputs": [],
   "source": [
     "# TODO run this\n",
-    "ring_condition = df[\"Metabolite\"].apply(\n",
-    "    lambda x: (x.match_stats[\"coverage\"] < 0.5) & bool(x.ring_proportion > 0.8)\n",
+    "ring_condition = df['Metabolite'].apply(\n",
+    "    lambda x: (x.match_stats['coverage'] < 0.5) & bool(x.ring_proportion > 0.8)\n",
     ")\n",
     "df_rings = df[ring_condition]\n",
     "\n",
@@ -893,7 +891,7 @@
   ],
   "source": [
     "for i, d in df_rings.head(1).iterrows():\n",
-    "    d[\"Metabolite\"].draw()\n",
+    "    d['Metabolite'].draw()\n",
     "    plt.show()"
   ]
  },
@@ -905,8 +903,8 @@
   "source": [
     "save_rings: bool = False\n",
     "if save_rings:\n",
-    "    path: str = f\"{home}/data/metabolites/preprocessed/rings_msnlib.csv\"\n",
-    "    print(f\"Saving ring dataframe to {path}\")\n",
+    "    path: str = f'{home}/data/metabolites/preprocessed/rings_msnlib.csv'\n",
+    "    print(f'Saving ring dataframe to {path}')\n",
     "    df_rings.to_csv(path)"
   ]
  },
@@ -929,16 +927,16 @@
     "\n",
     "# Hard filter conditions that must be fulfilled\n",
     "hard_filters: Dict[str, float] = {\n",
-    "    \"min_peaks\": 2,\n",
-    "    \"min_coverage\": 0.5,\n",
-    "    \"max_precursor_intensity\": 0.9,\n",
+    "    'min_peaks': 2,\n",
+    "    'min_coverage': 0.5,\n",
+    "    'max_precursor_intensity': 0.9,\n",
     "}\n",
     "\n",
     "# Soft conditions where at least one must be met\n",
     "soft_filters: Dict[str, float] = {\n",
-    "    \"desired_peaks\": 4,\n",
-    "    \"desired_coverage\": 0.75,\n",
-    "    \"desired_peak_percentage\": 0.5,  # Proportion of peaks covered by the fragmentation\n",
+    "    'desired_peaks': 4,\n",
+    "    'desired_coverage': 0.75,\n",
+    "    'desired_peak_percentage': 0.5,  # Proportion of peaks covered by the fragmentation\n",
     "}\n",
     "\n",
     "hard_filters_drops: Dict[str, list] = {key: [] for key in hard_filters.keys()}\n",
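The cell above only defines the thresholds; the code applying them is unchanged and therefore not shown in this diff. Per the two comments, every hard condition must hold while a single soft condition suffices. A minimal sketch of that logic, with the per-spectrum stat keys assumed for illustration:

```python
# Sketch only: the stat keys below are assumptions, not the notebook's API.
def passes_filters(stats: dict) -> bool:
    hard_ok = (
        stats["num_peaks"] >= 2                  # min_peaks
        and stats["coverage"] >= 0.5             # min_coverage
        and stats["precursor_intensity"] <= 0.9  # max_precursor_intensity
    )
    soft_ok = (
        stats["num_peaks"] >= 4                  # desired_peaks
        or stats["coverage"] >= 0.75             # desired_coverage
        or stats["peak_percentage"] >= 0.5       # desired_peak_percentage
    )
    return hard_ok and soft_ok


print(passes_filters({"num_peaks": 5, "coverage": 0.6,
                      "precursor_intensity": 0.4, "peak_percentage": 0.7}))  # True
```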
@@ -1010,15 +1008,15 @@
    }
   ],
   "source": [
-    "from matplotlib_venn import venn3\n",
     "import matplotlib.pyplot as plt\n",
+    "from matplotlib_venn import venn3\n",
     "\n",
     "# Extract the sets and labels from the dictionary\n",
     "sets = [set(hard_filters_drops[key]) for key in hard_filters_drops]\n",
     "key_to_label = {\n",
-    "    \"min_peaks\": \"Min Peaks\",\n",
-    "    \"min_coverage\": \"Min Coverage\",\n",
-    "    \"max_precursor_intensity\": \"Max Precursor Intensity\",\n",
+    "    'min_peaks': 'Min Peaks',\n",
+    "    'min_coverage': 'Min Coverage',\n",
+    "    'max_precursor_intensity': 'Max Precursor Intensity',\n",
     "}\n",
     "labels = [key_to_label[key] for key in hard_filters_drops.keys()]\n",
     "\n",
@@ -1030,11 +1028,11 @@
     "for subset in venn.patches:\n",
     "    if subset:  # Check if the subset exists (not None)\n",
     "        subset.set_linewidth(2)  # Set line thickness\n",
-    "        subset.set_edgecolor(\"black\")  # Set edge color to black\n",
+    "        subset.set_edgecolor('black')  # Set edge color to black\n",
     "\n",
     "\n",
     "# Add a title\n",
-    "plt.title(\"Overlap of Spectra Removed by Hard Filters\")\n",
+    "plt.title('Overlap of Spectra Removed by Hard Filters')\n",
     "plt.show()"
   ]
  },
@@ -1099,11 +1097,11 @@
    }
   ],
   "source": [
-    "from matplotlib_venn import venn3\n",
     "import matplotlib.pyplot as plt\n",
+    "from matplotlib_venn import venn3\n",
     "\n",
     "# Convert spectral indices to group IDs\n",
-    "group_ids = df[\"group_id\"].unique()\n",
+    "group_ids = df['group_id'].unique()\n",
     "\n",
     "# Identify compounds completely removed by hard filters\n",
     "hard_filter_indices = set()\n",
@@ -1113,7 +1111,7 @@
     "compounds_removed_by_hard_filters = {\n",
     "    group_id\n",
     "    for group_id in group_ids\n",
-    "    if all(idx in hard_filter_indices for idx in df[df[\"group_id\"] == group_id].index)\n",
+    "    if all(idx in hard_filter_indices for idx in df[df['group_id'] == group_id].index)\n",
     "}\n",
     "\n",
     "# Identify compounds removed by soft filters (or hard filters)\n",
@@ -1123,7 +1121,7 @@
     "    for group_id in group_ids\n",
     "    if all(\n",
     "        idx in hard_filter_indices or idx in soft_filter_indices\n",
-    "        for idx in df[df[\"group_id\"] == group_id].index\n",
+    "        for idx in df[df['group_id'] == group_id].index\n",
     "    )\n",
     "}\n",
     "\n",
@@ -1138,17 +1136,17 @@
     "        compounds_removed_by_hard_filters,\n",
     "        compounds_removed_by_soft_filters,\n",
     "    ],\n",
-    "    (\"All Compounds\", \"Removed By Hard Filter\", \"Removed By Soft Filters\"),\n",
+    "    ('All Compounds', 'Removed By Hard Filters', 'Removed By Soft Filters'),\n",
     ")\n",
     "\n",
     "# Customize the Venn diagram with thicker lines and black edges\n",
     "for subset in venn.patches:\n",
     "    if subset:  # Check if the subset exists (not None)\n",
     "        subset.set_linewidth(2)  # Set line thickness\n",
-    "        subset.set_edgecolor(\"black\")  # Set edge color to black\n",
+    "        subset.set_edgecolor('black')  # Set edge color to black\n",
     "\n",
     "# Add a title\n",
-    "plt.title(\"Overlap of Compounds Removed by Hard and Soft Filters\")\n",
+    "plt.title('Overlap of Compounds Removed by Hard and Soft Filters')\n",
     "plt.show()"
   ]
  },
@@ -1208,13 +1206,13 @@
   "source": [
     "sns.histplot(\n",
     "    df[\n",
-    "        df[\"SPECTYPE\"].isin(\n",
-    "            [\"SAME_ENERGY\", \"SINGLE_BEST_SCAN\"],\n",
+    "        df['SPECTYPE'].isin(\n",
+    "            ['SAME_ENERGY', 'SINGLE_BEST_SCAN'],\n",
     "        )\n",
     "    ],\n",
-    "    x=\"coverage\",\n",
-    "    hue=\"CE\",\n",
-    "    multiple=\"stack\",\n",
+    "    x='coverage',\n",
+    "    hue='CE',\n",
+    "    multiple='stack',\n",
     ")"
   ]
  },
@@ -1247,15 +1245,16 @@
   ],
   "source": [
     "from collections import defaultdict\n",
+    "\n",
     "from fiora.MOL.constants import DEFAULT_MODES\n",
     "\n",
     "\n",
     "def count_ion_mode_matches(df):\n",
     "    ion_mode_counts = defaultdict(float)\n",
     "\n",
-    "    for metabolite in df[\"Metabolite\"]:\n",
+    "    for metabolite in df['Metabolite']:\n",
     "        for peak_data in metabolite.peak_matches.values():\n",
-    "            ion_modes = peak_data[\"ion_modes\"]\n",
+    "            ion_modes = peak_data['ion_modes']\n",
     "            num_modes = len(ion_modes)\n",
     "            for mode, _ in ion_modes:  # mode is the ion mode string\n",
     "                if mode in DEFAULT_MODES:\n",
@@ -1267,24 +1266,24 @@
     "\n",
     "\n",
     "ion_mode_counts = count_ion_mode_matches(df)\n",
-    "ion_mode_df = pd.DataFrame(list(ion_mode_counts.items()), columns=[\"Ion Mode\", \"Count\"])\n",
+    "ion_mode_df = pd.DataFrame(list(ion_mode_counts.items()), columns=['Ion Mode', 'Count'])\n",
     "\n",
     "# Create a figure with two subplots\n",
     "fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
     "\n",
-    "sns.barplot(data=ion_mode_df, x=\"Ion Mode\", y=\"Count\", ax=axes[0], palette=\"viridis\")\n",
-    "axes[0].set_title(\"Ion Mode Counts (Bar Plot)\")\n",
-    "axes[0].set_xlabel(\"Ion Mode\")\n",
-    "axes[0].set_ylabel(\"Count\")\n",
-    "axes[0].tick_params(axis=\"x\", rotation=45)\n",
+    "sns.barplot(data=ion_mode_df, x='Ion Mode', y='Count', ax=axes[0], palette='viridis')\n",
+    "axes[0].set_title('Ion Mode Counts (Bar Plot)')\n",
+    "axes[0].set_xlabel('Ion Mode')\n",
+    "axes[0].set_ylabel('Count')\n",
+    "axes[0].tick_params(axis='x', rotation=45)\n",
     "axes[1].pie(\n",
-    "    ion_mode_df[\"Count\"],\n",
-    "    labels=ion_mode_df[\"Ion Mode\"],\n",
-    "    autopct=\"%1.1f%%\",\n",
+    "    ion_mode_df['Count'],\n",
+    "    labels=ion_mode_df['Ion Mode'],\n",
+    "    autopct='%1.1f%%',\n",
     "    startangle=90,\n",
-    "    colors=sns.color_palette(\"viridis\", len(ion_mode_df)),\n",
+    "    colors=sns.color_palette('viridis', len(ion_mode_df)),\n",
     ")\n",
-    "axes[1].set_title(\"Ion Mode Distribution (Pie Chart)\")\n",
+    "axes[1].set_title('Ion Mode Distribution (Pie Chart)')\n",
     "\n",
     "# Adjust layout and show the plots\n",
     "plt.tight_layout()\n",
@@ -1308,40 +1307,39 @@
     "source": [
     "from fiora.IO.LibraryLoader import LibraryLoader\n",
     "\n",
-    "\n",
     "train_ids, va_ids, test_ids = [], [], []\n",
     "\n",
     "L = LibraryLoader()\n",
-    "casmi16_path = f\"{home}/data/metabolites/CASMI_2016/casmi16_withCCS.csv\"\n",
-    "casmi22_path = f\"{home}/data/metabolites/CASMI_2022/casmi22_withCCS.csv\"\n",
+    "casmi16_path = f'{home}/data/metabolites/CASMI_2016/casmi16_withCCS.csv'\n",
+    "casmi22_path = f'{home}/data/metabolites/CASMI_2022/casmi22_withCCS.csv'\n",
     "df_merged = L.load_from_csv(\n",
-    "    f\"{home}/data/metabolites/preprocessed/datasplits_Jan24.csv\"\n",
+    "    f'{home}/data/metabolites/preprocessed/datasplits_Jan24.csv'\n",
     ")\n",
     "df_cas = pd.read_csv(casmi16_path, index_col=[0], low_memory=False)\n",
     "df_cas22 = pd.read_csv(casmi22_path, index_col=[0], low_memory=False)\n",
     "df_cast = pd.read_csv(\n",
-    "    f\"{home}/data/metabolites/CASMI_2016/casmi16t_withCCS.csv\",\n",
+    "    f'{home}/data/metabolites/CASMI_2016/casmi16t_withCCS.csv',\n",
     "    index_col=[0],\n",
     "    low_memory=False,\n",
     ")\n",
     "\n",
     "other_dfs = {\n",
-    "    \"train\": df_merged[df_merged[\"dataset\"] == \"training\"].drop_duplicates(\n",
-    "        subset=[\"group_id\"]\n",
+    "    'train': df_merged[df_merged['dataset'] == 'training'].drop_duplicates(\n",
+    "        subset=['group_id']\n",
     "    ),\n",
-    "    \"val\": df_merged[df_merged[\"dataset\"] == \"validation\"].drop_duplicates(\n",
-    "        subset=[\"group_id\"]\n",
+    "    'val': df_merged[df_merged['dataset'] == 'validation'].drop_duplicates(\n",
+    "        subset=['group_id']\n",
     "    ),\n",
-    "    \"test\": pd.concat(\n",
+    "    'test': pd.concat(\n",
     "        [\n",
-    "            df_merged[df_merged[\"dataset\"] == \"test\"].drop_duplicates(\n",
-    "                subset=[\"group_id\"]\n",
+    "            df_merged[df_merged['dataset'] == 'test'].drop_duplicates(\n",
+    "                subset=['group_id']\n",
     "            ),\n",
     "            df_cas,\n",
     "            df_cast,\n",
     "            df_cas22,\n",
     "        ]\n",
-    "    ).drop_duplicates(subset=[\"SMILES\"]),\n",
+    "    ).drop_duplicates(subset=['SMILES']),\n",
     "}"
   ]
  },
@@ -1351,15 +1349,15 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "lookup_table = {\"train\": set(), \"val\": set(), \"test\": set()}\n",
+    "lookup_table = {'train': set(), 'val': set(), 'test': set()}\n",
     "for key, df_x in other_dfs.items():\n",
-    "    df_x[\"Metabolite\"] = df_x[\"SMILES\"].apply(Metabolite)\n",
+    "    df_x['Metabolite'] = df_x['SMILES'].apply(Metabolite)\n",
     "\n",
     "    for i, data in df_x.iterrows():\n",
     "        lookup_table[key].add(\n",
     "            (\n",
-    "                data[\"Metabolite\"].ExactMolWeight,\n",
-    "                data[\"Metabolite\"].morganFingerCountOnes,\n",
+    "                data['Metabolite'].ExactMolWeight,\n",
+    "                data['Metabolite'].morganFingerCountOnes,\n",
     "            )\n",
     "        )"
   ]
  },
@@ -1373,27 +1371,27 @@
   "source": [
     "train, val, test = [], [], []\n",
     "\n",
     "\n",
-    "for id in df[\"group_id\"].unique():\n",
-    "    metabolite: Metabolite = df[df[\"group_id\"] == id].iloc[0][\"Metabolite\"]\n",
+    "for id in df['group_id'].unique():\n",
+    "    metabolite: Metabolite = df[df['group_id'] == id].iloc[0]['Metabolite']\n",
     "    fast_identifiers = (metabolite.ExactMolWeight, metabolite.morganFingerCountOnes)\n",
     "    found_match = False\n",
     "\n",
-    "    if fast_identifiers in lookup_table[\"train\"]:\n",
-    "        for i, data in other_dfs[\"train\"].iterrows():\n",
-    "            other_metabolite = data[\"Metabolite\"]\n",
+    "    if fast_identifiers in lookup_table['train']:\n",
+    "        for i, data in other_dfs['train'].iterrows():\n",
+    "            other_metabolite = data['Metabolite']\n",
     "            if metabolite == other_metabolite:\n",
     "                train.append(id)\n",
     "                found_match = True\n",
     "                break\n",
-    "    if not found_match and fast_identifiers in lookup_table[\"val\"]:\n",
-    "        for i, data in other_dfs[\"val\"].iterrows():\n",
-    "            other_metabolite = data[\"Metabolite\"]\n",
+    "    if not found_match and fast_identifiers in lookup_table['val']:\n",
+    "        for i, data in other_dfs['val'].iterrows():\n",
+    "            other_metabolite = data['Metabolite']\n",
     "            if metabolite == other_metabolite:\n",
     "                val.append(id)\n",
     "                break\n",
-    "    if not found_match and fast_identifiers in lookup_table[\"test\"]:\n",
-    "        for i, data in other_dfs[\"test\"].iterrows():\n",
-    "            other_metabolite = data[\"Metabolite\"]\n",
+    "    if not found_match and fast_identifiers in lookup_table['test']:\n",
+    "        for i, data in other_dfs['test'].iterrows():\n",
+    "            other_metabolite = data['Metabolite']\n",
     "            if metabolite == other_metabolite:\n",
     "                test.append(id)\n",
     "                break"
   ]
  },
@@ -1415,9 +1413,9 @@
   ],
   "source": [
     "print(\n",
-    "    f\"Preset compounds assigned to datasplits: {len(train)=} {len(val)=} {len(test)=}\"\n",
+    "    f'Preset compounds assigned to datasplits: {len(train)=} {len(val)=} {len(test)=}'\n",
     ")\n",
-    "print(f\"{train[:5]=} {val[-5:]=} {test[5:10]=}\")"
+    "print(f'{train[:5]=} {val[-5:]=} {test[5:10]=}')"
   ]
  },
 {
@@ -1447,7 +1445,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "group_ids = df[\"group_id\"].astype(int)\n",
+    "group_ids = df['group_id'].astype(int)\n",
     "keys = np.unique(group_ids)\n",
     "num_keys = len(keys)\n",
     "mask = ~np.isin(keys, train + val + test)\n",
@@ -1468,18 +1466,18 @@
     "test = np.concatenate((np.array(test), test_add))\n",
     "\n",
     "\n",
-    "df[\"dataset\"] = df[\"group_id\"].apply(\n",
+    "df['dataset'] = df['group_id'].apply(\n",
     "    lambda x: (\n",
-    "        \"training\"\n",
+    "        'training'\n",
     "        if x in train\n",
-    "        else \"validation\"\n",
+    "        else 'validation'\n",
     "        if x in val\n",
-    "        else \"test\"\n",
+    "        else 'test'\n",
     "        if x in test\n",
-    "        else \"VALUE ERROR\"\n",
+    "        else 'VALUE ERROR'\n",
     "    )\n",
     ")\n",
-    "df[\"datasplit\"] = df[\"dataset\"]"
+    "df['datasplit'] = df['dataset']"
   ]
  },
 {
@@ -1500,7 +1498,7 @@
    }
   ],
   "source": [
-    "print(f\"Unique compounds in each {df.groupby('datasplit')['group_id'].nunique()}\")"
+    "print(f'Unique compounds in each datasplit: {df.groupby(\"datasplit\")[\"group_id\"].nunique()}')"
   ]
  },
 {
@@ -1510,10 +1508,10 @@
   "outputs": [],
   "source": [
     "# Readjust loss weight\n",
-    "df[\"num_per_group\"] = df[\"group_id\"].map(df[\"group_id\"].value_counts())\n",
+    "df['num_per_group'] = df['group_id'].map(df['group_id'].value_counts())\n",
     "for i, data in df.iterrows():\n",
-    "    data[\"Metabolite\"].set_loss_weight(1.0 / data[\"num_per_group\"])\n",
-    "df[\"loss_weight\"] = df[\"Metabolite\"].apply(lambda x: x.loss_weight)"
+    "    data['Metabolite'].set_loss_weight(1.0 / data['num_per_group'])\n",
+    "df['loss_weight'] = df['Metabolite'].apply(lambda x: x.loss_weight)"
   ]
  },
 {
@@ -1541,7 +1539,7 @@
    }
   ],
   "source": [
-    "raise KeyboardInterrupt(\"Stop before saving\")"
+    "raise KeyboardInterrupt('Stop before saving')"
  ]
  },
 {
@@ -1560,8 +1558,8 @@
   "source": [
     "save_df: bool = False\n",
     "if save_df:\n",
-    "    path: str = f\"{home}/data/metabolites/preprocessed/datasplits_msnlib_v7_Sep25.csv\"  # Save with merge spectra (4)\n",
-    "    print(f\"Saving datasplits to {path}\")\n",
+    "    path: str = f'{home}/data/metabolites/preprocessed/datasplits_msnlib_v7_Sep25.csv'  # Save with merge spectra (4)\n",
+    "    print(f'Saving datasplits to {path}')\n",
     "    df.to_csv(path)"
   ]
  },
 {
@@ -1589,31 +1587,31 @@
    }
   ],
   "source": [
-    "for stat in df.iloc[0][\"Metabolite\"].match_stats.keys():\n",
-    "    df[stat] = df[\"Metabolite\"].apply(lambda x: x.match_stats[stat])\n",
+    "for stat in df.iloc[0]['Metabolite'].match_stats.keys():\n",
+    "    df[stat] = df['Metabolite'].apply(lambda x: x.match_stats[stat])\n",
     "\n",
     "fig, axs = plt.subplots(1, 2, figsize=(12.8, 6.4), sharey=False)\n",
     "\n",
-    "fig.suptitle(f\"Identified peak-fragment matches and number conflicts\")\n",
+    "fig.suptitle('Identified peak-fragment matches and number of conflicts')\n",
     "# plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n",
     "sns.histplot(\n",
     "    ax=axs[0],\n",
     "    data=df,\n",
-    "    x=\"num_peak_matches\",\n",
+    "    x='num_peak_matches',\n",
     "    color=color_palette[0],\n",
-    "    edgecolor=\"black\",\n",
+    "    edgecolor='black',\n",
     "    bins=range(0, 20, 1),\n",
     ")\n",
     "# axs[0].set_ylim(-0.5, 10)\n",
-    "axs[0].set_ylabel(\"peaks identified\")\n",
+    "axs[0].set_ylabel('peaks identified')\n",
     "\n",
     "\n",
     "sns.histplot(\n",
-    "    ax=axs[1], data=df, x=\"num_fragment_conflicts\", color=color_palette[3], binwidth=1\n",
+    "    ax=axs[1], data=df, x='num_fragment_conflicts', color=color_palette[3], binwidth=1\n",
     ")\n",
     "# axs[2].set_ylim(-0.5, 1000)\n",
-    "axs[1].set_xlabel(\"conflicts\")\n",
-    "axs[1].set_ylabel(\"\")\n",
+    "axs[1].set_xlabel('conflicts')\n",
+    "axs[1].set_ylabel('')\n",
     "\n",
     "plt.show()"
   ]
  },
@@ -1636,7 +1634,7 @@
   ],
   "source": [
     "# df[\"RETENTIONTIME\"] = df[\"RTINSECONDS\"].astype(float)\n",
-    "sns.displot(df, x=\"RTINSECONDS\", kde=True, binwidth=0.5, hue=\"ADDUCT\")\n",
+    "sns.displot(df, x='RTINSECONDS', kde=True, binwidth=0.5, hue='ADDUCT')\n",
     "plt.show()"
   ]
  },
 {
@@ -1662,11 +1660,11 @@
     "sns.histplot(\n",
     "    ax=ax,\n",
     "    data=df,\n",
-    "    x=\"RTINSECONDS\",\n",
-    "    hue=\"origin\",\n",
-    "    multiple=\"stack\",\n",
+    "    x='RTINSECONDS',\n",
+    "    hue='origin',\n",
+    "    multiple='stack',\n",
     "    binwidth=1,\n",
-    "    stat=\"probability\",\n",
+    "    stat='probability',\n",
     ")\n",
     "plt.show()"
   ]
  },
@@ -1691,9 +1689,9 @@
     "fig, ax = plt.subplots(1, 1, figsize=(12.8, 6.4))\n",
     "\n",
     "sns.kdeplot(\n",
-    "    ax=ax, data=df, x=\"RTINSECONDS\", hue=\"origin\", multiple=\"fill\", common_norm=False\n",
+    "    ax=ax, data=df, x='RTINSECONDS', hue='origin', multiple='fill', common_norm=False\n",
     ")\n",
-    "ax.legend(bbox_to_anchor=(1.5, 0.8), labels=df[\"origin\"].unique())\n",
+    "ax.legend(bbox_to_anchor=(1.5, 0.8), labels=df['origin'].unique())\n",
     "plt.show()"
   ]
  }
diff --git a/lib_loader/nist_library_loader.ipynb b/lib_loader/nist_library_loader.ipynb
index 8c66cd2..b085dad 100644
--- a/lib_loader/nist_library_loader.ipynb
+++ b/lib_loader/nist_library_loader.ipynb
@@ -32,54 +32,46 @@
   "source": [
     "import sys\n",
     "\n",
-    "print(f\"Working with Python {sys.version}\")\n",
+    "print(f'Working with Python {sys.version}')\n",
     "\n",
-    "import numpy as np\n",
-    "import pandas as pd\n",
+    "import collections\n",
     "import importlib\n",
+    "import os\n",
+    "import time\n",
     "\n",
     "# import swifter\n",
     "import matplotlib.pyplot as plt\n",
-    "import seaborn as sns\n",
-    "import collections\n",
-    "import time\n",
-    "import os\n",
-    "from rdkit import Chem\n",
-    "from rdkit.Chem import AllChem\n",
-    "from rdkit.Chem import Draw\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
     "import rdkit.Chem.Descriptors as Descriptors\n",
-    "from rdkit.Chem import PandasTools\n",
-    "#!pip install spektral\n",
-    "\n",
+    "import seaborn as sns\n",
     "\n",
+    "#!pip install spektral\n",
     "# Deep Learning\n",
     "import sklearn\n",
     "\n",
-    "# import spektral\n",
-    "from sklearn.model_selection import train_test_split\n",
-    "\n",
     "# Keras\n",
-    "from sklearn.model_selection import train_test_split\n",
-    "\n",
     "# import stellargraph as sg\n",
-    "from rdkit import RDLogger\n",
+    "from rdkit import Chem, RDLogger\n",
+    "from rdkit.Chem import AllChem, Draw, PandasTools\n",
     "\n",
+    "# import spektral\n",
+    "from sklearn.model_selection import train_test_split\n",
     "\n",
     "#\n",
     "# Load Modules\n",
-    "sys.path.append(\"..\")\n",
+    "sys.path.append('..')\n",
     "from os.path import expanduser\n",
     "\n",
-    "home = expanduser(\"~\")\n",
+    "home = expanduser('~')\n",
+    "import fiora.IO.molReader as molReader\n",
    "import fiora.IO.mspReader as mspReader\n",
     "import fiora.visualization.spectrum_visualizer as sv\n",
-    "import fiora.IO.molReader as molReader\n",
-    "\n",
     "\n",
-    "RDLogger.DisableLog(\"rdApp.*\")\n",
+    "RDLogger.DisableLog('rdApp.*')\n",
     "\n",
     "\n",
-    "caffeine_smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\"\n",
+    "caffeine_smiles = 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C'\n",
     "caffeine_mol = Chem.MolFromSmiles(caffeine_smiles)\n",
     "\n",
     "caffeine_mol"
   ]
  },
 {
@@ -91,8 +83,8 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "library_name = \"nist_msms\"\n",
-    "library_directory = f\"{home}/data/metabolites/NIST17/msp/nist_msms/\"\n",
+    "library_name = 'nist_msms'\n",
+    "library_directory = f'{home}/data/metabolites/NIST17/msp/nist_msms/'\n",
     "!ls $library_directory"
   ]
  },
@@ -128,13 +120,13 @@
    }
   ],
   "source": [
-    "nist_msp = mspReader.read(library_directory + library_name + \".MSP\")\n",
+    "nist_msp = mspReader.read(library_directory + library_name + '.MSP')\n",
     "df_nist = pd.DataFrame(nist_msp)\n",
     "\n",
     "# df_nist['mol'] = df_nist['SMILES'].apply(Chem.MolFromSmiles)\n",
     "# df_nist.dropna(inplace=True)\n",
     "print(\n",
-    "    f\"Spectral file loaded with {df_nist.shape[0]} entries and {df_nist.shape[1]} variables\"\n",
+    "    f'Spectral file loaded with {df_nist.shape[0]} entries and {df_nist.shape[1]} variables'\n",
     ")"
   ]
  },
@@ -174,18 +166,18 @@
   "outputs": [],
   "source": [
     "# Define figure styles\n",
-    "color_palette = sns.color_palette(\"magma_r\", 8)\n",
+    "color_palette = sns.color_palette('magma_r', 8)\n",
     "sns.set_theme(\n",
-    "    style=\"whitegrid\",\n",
+    "    style='whitegrid',\n",
     "    rc={\n",
-    "        \"axes.edgecolor\": \"black\",\n",
-    "        \"ytick.left\": True,\n",
-    "        \"xtick.bottom\": True,\n",
-    "        \"xtick.color\": \"black\",\n",
-    "        \"axes.spines.bottom\": True,\n",
-    "        \"axes.spines.right\": True,\n",
-    "        \"axes.spines.top\": True,\n",
-    "        \"axes.spines.left\": True,\n",
+    "        'axes.edgecolor': 'black',\n",
+    "        'ytick.left': True,\n",
+    "        'xtick.bottom': True,\n",
+    "        'xtick.color': 'black',\n",
+    "        'axes.spines.bottom': True,\n",
+    "        'axes.spines.right': True,\n",
+    "        'axes.spines.top': True,\n",
+    "        'axes.spines.left': True,\n",
     "    },\n",
     ")"
   ]
  },
@@ -197,25 +189,25 @@
   "outputs": [],
   "source": [
     "fig, axs = plt.subplots(\n",
-    "    1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 2]}, sharey=True\n",
+    "    1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 2]}, sharey=True\n",
     ")\n",
     "fig.set_tight_layout(False)\n",
     "for ax in axs:\n",
-    "    ax.tick_params(\"x\", labelrotation=45)\n",
+    "    ax.tick_params('x', labelrotation=45)\n",
     "\n",
     "sns.countplot(\n",
-    "    ax=axs[0], data=df_nist, x=\"Spectrum_type\", edgecolor=\"black\", palette=color_palette\n",
+    "    ax=axs[0], data=df_nist, x='Spectrum_type', edgecolor='black', palette=color_palette\n",
     ")\n",
     "sns.countplot(\n",
     "    ax=axs[1],\n",
     "    data=df_nist,\n",
-    "    x=\"Precursor_type\",\n",
-    "    edgecolor=\"black\",\n",
+    "    x='Precursor_type',\n",
+    "    edgecolor='black',\n",
     "    palette=color_palette,\n",
-    "    order=df_nist[\"Precursor_type\"].value_counts().iloc[:8].index,\n",
+    "    order=df_nist['Precursor_type'].value_counts().iloc[:8].index,\n",
     ")\n",
     "axs[0].set_ylim(0, 500000)\n",
-    "axs[1].set_ylabel(\"\")\n",
+    "axs[1].set_ylabel('')\n",
     "\n",
     "plt.show()"
   ]
  },
@@ -227,18 +219,18 @@
   "outputs": [],
   "source": [
     "# Filters\n",
-    "df_nist = df_nist[df_nist[\"Spectrum_type\"] == \"MS2\"]\n",
-    "target_precursor_type = [\"[M+H]+\", \"[M-H]-\", \"[M+H-H2O]+\", \"[M+Na]+\"]\n",
+    "df_nist = df_nist[df_nist['Spectrum_type'] == 'MS2']\n",
+    "target_precursor_type = ['[M+H]+', '[M-H]-', '[M+H-H2O]+', '[M+Na]+']\n",
     "df_nist = df_nist[\n",
-    "    df_nist[\"Precursor_type\"].apply(lambda ptype: ptype in target_precursor_type)\n",
+    "    df_nist['Precursor_type'].apply(lambda ptype: ptype in target_precursor_type)\n",
     "]\n",
     "\n",
     "# Formats\n",
-    "df_nist[\"PrecursorMZ\"] = df_nist[\"PrecursorMZ\"].astype(\"float\")\n",
-    "df_nist[\"Num peaks\"] = df_nist[\"Num peaks\"].astype(\"int\")\n",
+    "df_nist['PrecursorMZ'] = df_nist['PrecursorMZ'].astype('float')\n",
+    "df_nist['Num peaks'] = df_nist['Num peaks'].astype('int')\n",
     "\n",
     "\n",
-    "print(f\"Spectral file filtered down to {df_nist.shape[0]} entries\")"
+    "print(f'Spectral file filtered down to {df_nist.shape[0]} entries')"
   ]
  },
 {
@@ -248,27 +240,27 @@
   "outputs": [],
   "source": [
     "fig, axs = plt.subplots(\n",
-    "    1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 2]}, sharey=False\n",
+    "    1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 2]}, sharey=False\n",
     ")\n",
     "for ax in axs:\n",
-    "    ax.tick_params(\"x\", labelrotation=45)\n",
+    "    ax.tick_params('x', labelrotation=45)\n",
     "\n",
     "sns.boxplot(\n",
-    "    ax=axs[0], data=df_nist, y=\"PrecursorMZ\", palette=color_palette, x=\"Precursor_type\"\n",
+    "    ax=axs[0], data=df_nist, y='PrecursorMZ', palette=color_palette, x='Precursor_type'\n",
     ")\n",
     "sns.histplot(\n",
     "    ax=axs[1],\n",
     "    data=df_nist,\n",
-    "    x=\"Num peaks\",\n",
+    "    x='Num peaks',\n",
     "    color=color_palette[7],\n",
     "    fill=True,\n",
-    "    edgecolor=\"black\",\n",
+    "    edgecolor='black',\n",
     ")  # , order=list(range(0,200)))\n",
-    "plt.rcParams[\"patch.force_edgecolor\"] = True\n",
-    "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n",
-    "axs[1].set_ylabel(\"\")\n",
+    "plt.rcParams['patch.force_edgecolor'] = True\n",
+    "plt.rcParams['axes.edgecolor'] = 'black'\n",
+    "axs[1].set_ylabel('')\n",
     "axs[1].set_xlim([0, 100])\n",
-    "axs[1].axvline(x=50, color=\"red\", linestyle=\"-.\")\n",
+    "axs[1].axvline(x=50, color='red', linestyle='-.')\n",
     "\n",
     "plt.show()"
   ]
  },
@@ -281,24 +273,24 @@
   "source": [
     "# associate MOL structures with MS2 spectra\n",
     "\n",
-    "file = library_directory + library_name + \".MOL/\" + \"S\" + x[\"CASNO\"] + \".MOL\"\n",
+    "file = library_directory + library_name + '.MOL/' + 'S' + x['CASNO'] + '.MOL'\n",
     "x_mol = molReader.load_MOL(file)\n",
     "x_mol\n",
     "\n",
     "fig, axs = plt.subplots(\n",
-    "    1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n",
+    "    1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n",
     ")\n",
     "\n",
     "img = Chem.Draw.MolToImage(x_mol, ax=axs[0])\n",
     "\n",
     "axs[0].grid(False)\n",
     "axs[0].tick_params(\n",
-    "    axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n",
+    "    axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n",
     ")\n",
-    "axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + Chem.MolToSmiles(x_mol))\n",
+    "axs[0].set_title(x['Name'] + ' structure:\\n' + Chem.MolToSmiles(x_mol))\n",
     "axs[0].imshow(img)\n",
-    "axs[0].axis(\"off\")\n",
-    "sv.plot_spectrum(title=x[\"Name\"] + \" MS/MS spectrum\", spectrum=x, ax=axs[1])"
+    "axs[0].axis('off')\n",
+    "sv.plot_spectrum(title=x['Name'] + ' MS/MS spectrum', spectrum=x, ax=axs[1])"
   ]
  },
 {
@@ -312,34 +304,34 @@
     "# print(df_nist.loc[1474])\n",
     "\n",
     "print(\n",
-    "    \"Reading structure information in MOL format from library files (this may take a while)\"\n",
+    "    'Reading structure information in MOL format from library files (this may take a while)'\n",
     ")\n",
     "\n",
     "\n",
     "def fetch_mol(data):\n",
     "    file = (\n",
-    "        library_directory + library_name + \".MOL/\" + \"S\" + str(data[\"CASNO\"]) + \".MOL\"\n",
+    "        library_directory + library_name + '.MOL/' + 'S' + str(data['CASNO']) + '.MOL'\n",
     "    )\n",
     "    if not os.path.exists(file):\n",
     "        file = (\n",
-    "            library_directory + library_name + \".MOL/\" + \"ID\" + str(data[\"ID\"]) + \".MOL\"\n",
+    "            library_directory + library_name + '.MOL/' + 'ID' + str(data['ID']) + '.MOL'\n",
     "        )\n",
     "    return molReader.load_MOL(file)\n",
     "\n",
     "\n",
     "df_nist = df_nist[\n",
-    "    ~df_nist[\"InChIKey\"].isnull()\n",
+    "    ~df_nist['InChIKey'].isnull()\n",
     "]  # Drop all without key (Not necessarily necessary)\n",
-    "df_nist[\"MOL\"] = df_nist.apply(fetch_mol, axis=1)\n",
+    "df_nist['MOL'] = df_nist.apply(fetch_mol, axis=1)\n",
     "print(\n",
-    "    f\"Successfully interpreted {sum(df_nist['MOL'].notna())} from {df_nist.shape[0]} entries. Dropping the rest.\"\n",
+    "    f'Successfully interpreted {sum(df_nist[\"MOL\"].notna())} from {df_nist.shape[0]} entries. Dropping the rest.'\n",
     ")\n",
     "\n",
-    "df_nist = df_nist[df_nist[\"MOL\"].notna()]\n",
-    "df_nist[\"SMILES\"] = df_nist[\"MOL\"].apply(Chem.MolToSmiles)\n",
-    "df_nist[\"InChI\"] = df_nist[\"MOL\"].apply(Chem.MolToInchi)\n",
-    "df_nist[\"K\"] = df_nist[\"MOL\"].apply(Chem.MolToInchiKey)\n",
-    "df_nist[\"ExactMolWeight\"] = df_nist[\"MOL\"].apply(Chem.Descriptors.ExactMolWt)"
+    "df_nist = df_nist[df_nist['MOL'].notna()]\n",
+    "df_nist['SMILES'] = df_nist['MOL'].apply(Chem.MolToSmiles)\n",
+    "df_nist['InChI'] = df_nist['MOL'].apply(Chem.MolToInchi)\n",
+    "df_nist['K'] = df_nist['MOL'].apply(Chem.MolToInchiKey)\n",
+    "df_nist['ExactMolWeight'] = df_nist['MOL'].apply(Chem.Descriptors.ExactMolWt)"
   ]
  },
 {
@@ -348,14 +340,14 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "print(df_nist[df_nist[\"Name\"] == example_entry].iloc[0])\n",
-    "print(len(df_nist[\"MOL\"].unique()))\n",
+    "print(df_nist[df_nist['Name'] == example_entry].iloc[0])\n",
+    "print(len(df_nist['MOL'].unique()))\n",
     "\n",
     "\n",
     "df_nist.shape\n",
-    "df_nist[\"SMILES\"].isnull().any()\n",
+    "df_nist['SMILES'].isnull().any()\n",
     "\n",
-    "df_nist[df_nist[\"Name\"] == example_entry].iloc[0][\"SMILES\"]"
+    "df_nist[df_nist['Name'] == example_entry].iloc[0]['SMILES']"
   ]
  },
 {
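The next hunk touches the cell that cross-checks the computed InChIKeys (`K`) against the keys shipped with the library. As background (not stated in the patch itself): an InChIKey consists of three hyphen-separated blocks, and the first 14-character block hashes the connectivity (main) layer, so comparing `split('-')[0]` tolerates differences in stereochemistry and protonation. A small illustration with caffeine's standard InChIKey:

```python
# Hard-coded example values for illustration only.
provided = "RYYVLZVUVIJVGH-UHFFFAOYSA-N"
computed = "RYYVLZVUVIJVGH-UHFFFAOYSA-N"

full_match = provided == computed                                    # all layers agree
main_layer_match = provided.split("-")[0] == computed.split("-")[0]  # connectivity only
print(full_match, main_layer_match)  # True True
```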
Result: {s} ({half_keys.sum() / len(half_keys):0.3f} correct)'\n", ")\n", "\n", - "print(\"Dropping all other.\")\n", - "df_nist[\"matching_key\"] = df_nist.apply(lambda x: x[\"InChIKey\"] == x[\"K\"], axis=1)\n", - "df_nist = df_nist[df_nist[\"matching_key\"]]\n", + "print('Dropping all other.')\n", + "df_nist['matching_key'] = df_nist.apply(lambda x: x['InChIKey'] == x['K'], axis=1)\n", + "df_nist = df_nist[df_nist['matching_key']]\n", "\n", - "print(f\"Shape: {df_nist.shape}\")" + "print(f'Shape: {df_nist.shape}')" ] }, { @@ -390,7 +382,7 @@ "metadata": {}, "outputs": [], "source": [ - "df_nist[\"ExactMolWeight\"] = df_nist[\"MOL\"].apply(Chem.Descriptors.ExactMolWt)" + "df_nist['ExactMolWeight'] = df_nist['MOL'].apply(Chem.Descriptors.ExactMolWt)" ] }, { @@ -419,23 +411,22 @@ "source": [ "MIN_PEAKS = 2\n", "MAX_PEAKS = 30\n", - "PRECURSOR_TYPES = [\"[M+H]+\"]\n", + "PRECURSOR_TYPES = ['[M+H]+']\n", "from modules.MOL.constants import ADDUCT_WEIGHTS\n", "\n", - "\n", - "df_nist = df_nist[df_nist[\"Num peaks\"] > MIN_PEAKS]\n", - "df_nist = df_nist[df_nist[\"Num peaks\"] < MAX_PEAKS] # TODO WHY MAX CUTOFF: REMOVE!!\n", - "df_nist[\"theoretical_precursor_mz\"] = df_nist[\"ExactMolWeight\"] + df_nist[\n", - " \"Precursor_type\"\n", + "df_nist = df_nist[df_nist['Num peaks'] > MIN_PEAKS]\n", + "df_nist = df_nist[df_nist['Num peaks'] < MAX_PEAKS] # TODO WHY MAX CUTOFF: REMOVE!!\n", + "df_nist['theoretical_precursor_mz'] = df_nist['ExactMolWeight'] + df_nist[\n", + " 'Precursor_type'\n", "].map(ADDUCT_WEIGHTS)\n", "df_nist = df_nist[\n", - " df_nist[\"Precursor_type\"].apply(lambda ptype: ptype in PRECURSOR_TYPES)\n", + " df_nist['Precursor_type'].apply(lambda ptype: ptype in PRECURSOR_TYPES)\n", "]\n", - "df_nist[\"precursor_offset\"] = (\n", - " df_nist[\"PrecursorMZ\"] - df_nist[\"theoretical_precursor_mz\"]\n", + "df_nist['precursor_offset'] = (\n", + " df_nist['PrecursorMZ'] - df_nist['theoretical_precursor_mz']\n", ")\n", "\n", - "print(f\"Shape {df_nist.shape}\")" + "print(f'Shape {df_nist.shape}')" ] }, { @@ -445,29 +436,29 @@ "outputs": [], "source": [ "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 1.5]}, sharey=False\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 1.5]}, sharey=False\n", ")\n", "for ax in axs:\n", - " ax.tick_params(\"x\", labelrotation=45)\n", + " ax.tick_params('x', labelrotation=45)\n", "\n", "sns.scatterplot(\n", " ax=axs[0],\n", " data=df_nist,\n", - " x=\"precursor_offset\",\n", - " y=\"PrecursorMZ\",\n", + " x='precursor_offset',\n", + " y='PrecursorMZ',\n", " palette=color_palette,\n", ")\n", "sns.histplot(\n", " ax=axs[1],\n", " data=df_nist,\n", - " x=\"precursor_offset\",\n", + " x='precursor_offset',\n", " color=color_palette[7],\n", " fill=True,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", ") # , order=list(range(0,200)))\n", - "plt.rcParams[\"patch.force_edgecolor\"] = True\n", - "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n", - "axs[1].set_ylabel(\"\")\n", + "plt.rcParams['patch.force_edgecolor'] = True\n", + "plt.rcParams['axes.edgecolor'] = 'black'\n", + "axs[1].set_ylabel('')\n", "# axs[1].set_xlim([0, 100])\n", "# axs[1].axvline(x=50, color=\"red\", linestyle=\"-.\")\n", "\n", @@ -523,12 +514,12 @@ "\n", "from modules.MOL.collision_energy import align_CE\n", "\n", - "df_nist[\"CE\"] = df_nist.apply(\n", - " lambda x: align_CE(x[\"Collision_energy\"], x[\"theoretical_precursor_mz\"]), axis=1\n", + "df_nist['CE'] = df_nist.apply(\n", + " lambda x: 
align_CE(x['Collision_energy'], x['theoretical_precursor_mz']), axis=1\n", ") # modules.MOL.collision_energy.align_CE)\n", - "df_nist[\"CE_type\"] = df_nist[\"CE\"].apply(type)\n", - "df_nist[\"CE_derived_from_NCE\"] = df_nist[\"Collision_energy\"].apply(\n", - " lambda x: \"%\" in str(x)\n", + "df_nist['CE_type'] = df_nist['CE'].apply(type)\n", + "df_nist['CE_derived_from_NCE'] = df_nist['Collision_energy'].apply(\n", + " lambda x: '%' in str(x)\n", ")\n", "# df_test = df_nist[df_nist[\"Collision_energy\"].apply(lambda x: \"%\" in str(x))][\"Collision_energy\"]\n", "# df_test = df_test.apply(lambda x: x.split('%')[0].strip().split(' ')[-1])\n", @@ -540,17 +531,17 @@ "# TODO FIND MORE CE derived from different NCE types\n", "\n", "print(\n", - " \"Distinguish CE absolute values (eV - float) and normalized CE (in % - str format)\"\n", + " 'Distinguish CE absolute values (eV - float) and normalized CE (in % - str format)'\n", ")\n", - "print(df_nist[\"CE_type\"].value_counts())\n", + "print(df_nist['CE_type'].value_counts())\n", "\n", - "print(\"Removing all but absolute values\")\n", - "df_nist = df_nist[df_nist[\"CE_type\"] == float]\n", - "df_nist = df_nist[~df_nist[\"CE\"].isnull()]\n", + "print('Removing all but absolute values')\n", + "df_nist = df_nist[df_nist['CE_type'] == float]\n", + "df_nist = df_nist[~df_nist['CE'].isnull()]\n", "# len(df_nist['CE'].unique())\n", "\n", "print(\n", - " f\"Detected {len(df_nist['CE'].unique())} unique collision energies in range from {np.min(df_nist['CE'])} to {max(df_nist['CE'])} eV\"\n", + " f'Detected {len(df_nist[\"CE\"].unique())} unique collision energies in range from {np.min(df_nist[\"CE\"])} to {max(df_nist[\"CE\"])} eV'\n", ")" ] }, @@ -560,7 +551,7 @@ "metadata": {}, "outputs": [], "source": [ - "df_nist[df_nist[\"Instrument_type\"] == \"HCD\"][\"Collision_energy\"].value_counts()[:100]" + "df_nist[df_nist['Instrument_type'] == 'HCD']['Collision_energy'].value_counts()[:100]" ] }, { @@ -577,19 +568,19 @@ "sns.histplot(\n", " ax=ax,\n", " data=df_nist,\n", - " x=\"CE\",\n", - " hue=\"CE_derived_from_NCE\",\n", + " x='CE',\n", + " hue='CE_derived_from_NCE',\n", " palette=[color_palette[4], color_palette[2]],\n", - " multiple=\"stack\",\n", + " multiple='stack',\n", " fill=True,\n", " binwidth=2,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " binrange=[0, 200],\n", ") # , order=list(range(0,200)))\n", - "plt.rcParams[\"patch.force_edgecolor\"] = True\n", - "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n", + "plt.rcParams['patch.force_edgecolor'] = True\n", + "plt.rcParams['axes.edgecolor'] = 'black'\n", "plt.show()\n", - "print(f\"{df_nist.shape[0]} spectra remaining with aligned absolute collision energies\")" + "print(f'{df_nist.shape[0]} spectra remaining with aligned absolute collision energies')" ] }, { @@ -610,19 +601,19 @@ "outputs": [], "source": [ "%%capture\n", - "from modules.MOL.Metabolite import Metabolite\n", "from modules.MOL.constants import PPM\n", + "from modules.MOL.Metabolite import Metabolite\n", "\n", "TOLERANCE = 200 * PPM\n", "\n", "\n", - "df_nist[\"Metabolite\"] = df_nist[\"SMILES\"].apply(Metabolite)\n", - "df_nist[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", - "df_nist[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes())\n", - "df_nist[\"Metabolite\"].apply(lambda x: x.fragment_MOL())\n", + "df_nist['Metabolite'] = df_nist['SMILES'].apply(Metabolite)\n", + "df_nist['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", + 
"df_nist['Metabolite'].apply(lambda x: x.compute_graph_attributes())\n", + "df_nist['Metabolite'].apply(lambda x: x.fragment_MOL())\n", "df_nist.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=TOLERANCE\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=TOLERANCE\n", " ),\n", " axis=1,\n", ")" @@ -639,21 +630,21 @@ "x = df_nist.loc[EXAMPLE_ID]\n", "\n", "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 6.4), gridspec_kw={\"width_ratios\": [1, 1]}, sharey=False\n", + " 1, 2, figsize=(12.8, 6.4), gridspec_kw={'width_ratios': [1, 1]}, sharey=False\n", ")\n", "\n", "img = Chem.Draw.MolToImage(x_mol, ax=axs[0])\n", "\n", "axs[0].grid(False)\n", "axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", ")\n", - "axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + Chem.MolToSmiles(x_mol))\n", + "axs[0].set_title(x['Name'] + ' structure:\\n' + Chem.MolToSmiles(x_mol))\n", "axs[0].imshow(img)\n", - "axs[0].axis(\"off\")\n", + "axs[0].axis('off')\n", "\n", - "g_img = draw_graph(x[\"Metabolite\"].Graph, ax=axs[1])\n", - "print(x[\"peaks\"])" + "g_img = draw_graph(x['Metabolite'].Graph, ax=axs[1])\n", + "print(x['peaks'])" ] }, { @@ -691,27 +682,27 @@ "source": [ "x = df_nist.loc[EXAMPLE_ID]\n", "\n", - "FT = x[\"Metabolite\"].fragmentation_tree\n", + "FT = x['Metabolite'].fragmentation_tree\n", "# frag.build_fragmentation_tree_by_rotatable_bond_breaks()\n", "print(FT)\n", "\n", "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(12.8, 2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", ")\n", "\n", - "img = Chem.Draw.MolToImage(x[\"MOL\"], ax=axs[0])\n", + "img = Chem.Draw.MolToImage(x['MOL'], ax=axs[0])\n", "\n", "axs[0].grid(False)\n", "axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", ")\n", - "axs[0].set_title(x[\"Name\"] + \" structure:\\n\" + x[\"SMILES\"])\n", + "axs[0].set_title(x['Name'] + ' structure:\\n' + x['SMILES'])\n", "axs[0].imshow(img)\n", - "axs[0].axis(\"off\")\n", - "sv.plot_spectrum(title=x[\"Name\"] + \" MS/MS spectrum\", spectrum=x, ax=axs[1])\n", + "axs[0].axis('off')\n", + "sv.plot_spectrum(title=x['Name'] + ' MS/MS spectrum', spectrum=x, ax=axs[1])\n", "\n", - "print(\"Matching peaks to fragments\")\n", - "print(x[\"Metabolite\"].peak_matches)" + "print('Matching peaks to fragments')\n", + "print(x['Metabolite'].peak_matches)" ] }, { @@ -730,17 +721,17 @@ "source": [ "from modules.MOL.constants import DEFAULT_MODES\n", "\n", - "df_nist[\"peak_matches\"] = df_nist[\"Metabolite\"].apply(\n", - " lambda x: getattr(x, \"peak_matches\")\n", + "df_nist['peak_matches'] = df_nist['Metabolite'].apply(\n", + " lambda x: getattr(x, 'peak_matches')\n", ")\n", - "df_nist[\"num_peaks_matched\"] = df_nist[\"peak_matches\"].apply(len)\n", + "df_nist['num_peaks_matched'] = df_nist['peak_matches'].apply(len)\n", "\n", "\n", "def get_match_stats(matches, mode_count={m: 0 for m in DEFAULT_MODES}):\n", " num_unique, num_conflicts = 0, 0\n", " for mz, match_data in matches.items():\n", " # candidates = match_data[\"fragments\"]\n", - " ion_modes = match_data[\"ion_modes\"]\n", 
+ " ion_modes = match_data['ion_modes']\n", " if len(ion_modes) == 1:\n", " num_unique += 1\n", " elif len(ion_modes) > 1:\n", @@ -750,20 +741,20 @@ " return num_unique, num_conflicts, mode_count\n", "\n", "\n", - "df_nist[\"match_stats\"] = df_nist[\"peak_matches\"].apply(lambda x: get_match_stats(x))\n", - "df_nist[\"num_unique_peaks_matched\"] = df_nist.apply(\n", - " lambda x: x[\"match_stats\"][0], axis=1\n", + "df_nist['match_stats'] = df_nist['peak_matches'].apply(lambda x: get_match_stats(x))\n", + "df_nist['num_unique_peaks_matched'] = df_nist.apply(\n", + " lambda x: x['match_stats'][0], axis=1\n", ")\n", - "df_nist[\"num_conflicts_in_peak_matching\"] = df_nist.apply(\n", - " lambda x: x[\"match_stats\"][1], axis=1\n", + "df_nist['num_conflicts_in_peak_matching'] = df_nist.apply(\n", + " lambda x: x['match_stats'][1], axis=1\n", ")\n", - "df_nist[\"match_mode_counts\"] = df_nist.apply(lambda x: x[\"match_stats\"][2], axis=1)\n", - "u = df_nist[\"num_unique_peaks_matched\"].sum()\n", - "s = df_nist[\"num_conflicts_in_peak_matching\"].sum()\n", + "df_nist['match_mode_counts'] = df_nist.apply(lambda x: x['match_stats'][2], axis=1)\n", + "u = df_nist['num_unique_peaks_matched'].sum()\n", + "s = df_nist['num_conflicts_in_peak_matching'].sum()\n", "print(\n", - " f\"Total number of uniquely matched peaks: {u} , conflicts found within {s} matches ({100 * s / (u + s):.02f} %))\"\n", + " f'Total number of uniquely matched peaks: {u} , conflicts found within {s} matches ({100 * s / (u + s):.02f} %))'\n", ")\n", - "print(f\"Total number of conflicting peak to fragment matches: {s}\")\n", + "print(f'Total number of conflicting peak to fragment matches: {s}')\n", "\n", "df_nist.shape" ] @@ -776,34 +767,34 @@ "source": [ "fig, axs = plt.subplots(1, 3, figsize=(12.8, 6.4), sharey=False)\n", "\n", - "fig.suptitle(f\"Identified peaks with fragment offset\")\n", + "fig.suptitle('Identified peaks with fragment offset')\n", "# plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n", "sns.histplot(\n", " ax=axs[0],\n", " data=df_nist,\n", - " x=\"num_peaks_matched\",\n", + " x='num_peaks_matched',\n", " color=color_palette[0],\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " bins=range(0, 20, 1),\n", ")\n", "# axs[0].set_ylim(-0.5, 10)\n", - "axs[0].set_ylabel(\"peaks identified\")\n", + "axs[0].set_ylabel('peaks identified')\n", "\n", "\n", "sns.boxplot(\n", - " ax=axs[1], data=df_nist, y=\"num_unique_peaks_matched\", color=color_palette[1]\n", + " ax=axs[1], data=df_nist, y='num_unique_peaks_matched', color=color_palette[1]\n", ")\n", "axs[1].set_ylim(-0.5, 15)\n", - "axs[1].set_xlabel(\"unique matches\")\n", - "axs[1].set_ylabel(\"\")\n", + "axs[1].set_xlabel('unique matches')\n", + "axs[1].set_ylabel('')\n", "\n", "\n", "sns.boxplot(\n", - " ax=axs[2], data=df_nist, y=\"num_conflicts_in_peak_matching\", color=color_palette[3]\n", + " ax=axs[2], data=df_nist, y='num_conflicts_in_peak_matching', color=color_palette[3]\n", ")\n", "axs[2].set_ylim(-0.5, 15)\n", - "axs[2].set_xlabel(\"conflicts\")\n", - "axs[2].set_ylabel(\"\")\n", + "axs[2].set_xlabel('conflicts')\n", + "axs[2].set_ylabel('')\n", "\n", "plt.show()" ] @@ -824,21 +815,21 @@ " mode_counts[mode] += m[mode]\n", "\n", "\n", - "df_nist[\"match_mode_counts\"].apply(update_mode_counts)\n", + "df_nist['match_mode_counts'].apply(update_mode_counts)\n", "\n", "sns.barplot(\n", " ax=axs[0],\n", " x=list(mode_counts.keys()),\n", " y=[mode_counts[k] for k in mode_counts.keys()],\n", " palette=color_palette,\n", - " 
edgecolor=\"black\",\n", + " edgecolor='black',\n", " linewidth=1.5,\n", ")\n", "axs[1].pie(\n", " [mode_counts[k] for k in mode_counts.keys()],\n", " labels=list(mode_counts.keys()),\n", " colors=color_palette,\n", - " wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5},\n", + " wedgeprops={'edgecolor': 'black', 'linewidth': 1.5},\n", ")\n", "\n", "plt.show()" @@ -852,8 +843,8 @@ "source": [ "for i in range(0, 6):\n", " print(\n", - " f\"Minimum {i} unique peaks identified (including precursors): \",\n", - " (df_nist[\"num_unique_peaks_matched\"] >= i).sum(),\n", + " f'Minimum {i} unique peaks identified (including precursors): ',\n", + " (df_nist['num_unique_peaks_matched'] >= i).sum(),\n", " )" ] }, @@ -884,23 +875,22 @@ "\n", "MIN_PEAKS = 2\n", "MAX_PEAKS = 30\n", - "PRECURSOR_TYPES = [\"[M-H]-\"]\n", + "PRECURSOR_TYPES = ['[M-H]-']\n", "from modules.MOL.constants import ADDUCT_WEIGHTS\n", "\n", - "\n", - "df_minus = df_minus[df_minus[\"Num peaks\"] > MIN_PEAKS]\n", - "df_minus = df_minus[df_minus[\"Num peaks\"] < MAX_PEAKS] # TODO WHY MAX CUTOFF: REMOVE!!\n", - "df_minus[\"theoretical_precursor_mz\"] = df_minus[\"ExactMolWeight\"] + df_minus[\n", - " \"Precursor_type\"\n", + "df_minus = df_minus[df_minus['Num peaks'] > MIN_PEAKS]\n", + "df_minus = df_minus[df_minus['Num peaks'] < MAX_PEAKS] # TODO WHY MAX CUTOFF: REMOVE!!\n", + "df_minus['theoretical_precursor_mz'] = df_minus['ExactMolWeight'] + df_minus[\n", + " 'Precursor_type'\n", "].map(ADDUCT_WEIGHTS)\n", "df_minus = df_minus[\n", - " df_minus[\"Precursor_type\"].apply(lambda ptype: ptype in PRECURSOR_TYPES)\n", + " df_minus['Precursor_type'].apply(lambda ptype: ptype in PRECURSOR_TYPES)\n", "]\n", - "df_minus[\"precursor_offset\"] = (\n", - " df_minus[\"PrecursorMZ\"] - df_minus[\"theoretical_precursor_mz\"]\n", + "df_minus['precursor_offset'] = (\n", + " df_minus['PrecursorMZ'] - df_minus['theoretical_precursor_mz']\n", ")\n", "\n", - "print(f\"Shape {df_minus.shape}\")" + "print(f'Shape {df_minus.shape}')" ] }, { @@ -924,28 +914,27 @@ "\n", "from modules.MOL.collision_energy import NCE_to_eV # align_CE,\n", "\n", - "\n", - "df_minus[\"CE\"] = df_minus.apply(\n", - " lambda x: align_CE(x[\"Collision_energy\"], x[\"theoretical_precursor_mz\"]), axis=1\n", + "df_minus['CE'] = df_minus.apply(\n", + " lambda x: align_CE(x['Collision_energy'], x['theoretical_precursor_mz']), axis=1\n", ") # modules.MOL.collision_energy.align_CE)\n", - "df_minus[\"CE_type\"] = df_minus[\"CE\"].apply(type)\n", - "df_minus[\"CE_derived_from_NCE\"] = df_minus[\"Collision_energy\"].apply(\n", - " lambda x: \"%\" in str(x)\n", + "df_minus['CE_type'] = df_minus['CE'].apply(type)\n", + "df_minus['CE_derived_from_NCE'] = df_minus['Collision_energy'].apply(\n", + " lambda x: '%' in str(x)\n", ")\n", "\n", "\n", "print(\n", - " \"Distinguish CE absolute values (eV - float) and normalized CE (in % - str format)\"\n", + " 'Distinguish CE absolute values (eV - float) and normalized CE (in % - str format)'\n", ")\n", - "print(df_minus[\"CE_type\"].value_counts())\n", + "print(df_minus['CE_type'].value_counts())\n", "\n", - "print(\"Removing all but absolute values\")\n", - "df_minus = df_minus[df_minus[\"CE_type\"] == float]\n", - "df_minus = df_minus[~df_minus[\"CE\"].isnull()]\n", + "print('Removing all but absolute values')\n", + "df_minus = df_minus[df_minus['CE_type'] == float]\n", + "df_minus = df_minus[~df_minus['CE'].isnull()]\n", "# len(df_nist['CE'].unique())\n", "\n", "print(\n", - " f\"Detected {len(df_minus['CE'].unique())} unique collision 
energies in range from {np.min(df_minus['CE'])} to {max(df_minus['CE'])} eV\"\n", + " f'Detected {len(df_minus[\"CE\"].unique())} unique collision energies in range from {np.min(df_minus[\"CE\"])} to {max(df_minus[\"CE\"])} eV'\n", ")" ] }, @@ -963,19 +952,19 @@ "sns.histplot(\n", " ax=ax,\n", " data=df_minus,\n", - " x=\"CE\",\n", - " hue=\"CE_derived_from_NCE\",\n", + " x='CE',\n", + " hue='CE_derived_from_NCE',\n", " palette=[color_palette[4], color_palette[2]],\n", - " multiple=\"stack\",\n", + " multiple='stack',\n", " fill=True,\n", " binwidth=2,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " binrange=[0, 200],\n", ") # , order=list(range(0,200)))\n", - "plt.rcParams[\"patch.force_edgecolor\"] = True\n", - "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n", + "plt.rcParams['patch.force_edgecolor'] = True\n", + "plt.rcParams['axes.edgecolor'] = 'black'\n", "plt.show()\n", - "print(f\"{df_minus.shape[0]} spectra remaining with aligned absolute collision energies\")" + "print(f'{df_minus.shape[0]} spectra remaining with aligned absolute collision energies')" ] }, { @@ -996,15 +985,15 @@ "outputs": [], "source": [ "%%capture\n", - "df_minus[\"Metabolite\"] = df_minus[\"SMILES\"].apply(Metabolite)\n", - "df_minus[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", - "df_minus[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes())\n", - "df_minus[\"Metabolite\"].apply(lambda x: x.fragment_MOL())\n", + "df_minus['Metabolite'] = df_minus['SMILES'].apply(Metabolite)\n", + "df_minus['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", + "df_minus['Metabolite'].apply(lambda x: x.compute_graph_attributes())\n", + "df_minus['Metabolite'].apply(lambda x: x.fragment_MOL())\n", "# df_minus[\"Metabolite\"].apply(lambda x: x.fragmentation_tree.set_fragment_modes(constants.NEGATIVE_MODES))\n", "\n", "df_minus.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=TOLERANCE\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=TOLERANCE\n", " ),\n", " axis=1,\n", ")" @@ -1052,26 +1041,26 @@ } ], "source": [ - "df_minus[\"peak_matches\"] = df_minus[\"Metabolite\"].apply(\n", - " lambda x: getattr(x, \"peak_matches\")\n", + "df_minus['peak_matches'] = df_minus['Metabolite'].apply(\n", + " lambda x: getattr(x, 'peak_matches')\n", ")\n", - "df_minus[\"num_peaks_matched\"] = df_minus[\"peak_matches\"].apply(len)\n", + "df_minus['num_peaks_matched'] = df_minus['peak_matches'].apply(len)\n", "\n", "\n", - "df_minus[\"match_stats\"] = df_minus[\"peak_matches\"].apply(lambda x: get_match_stats(x))\n", - "df_minus[\"num_unique_peaks_matched\"] = df_minus.apply(\n", - " lambda x: x[\"match_stats\"][0], axis=1\n", + "df_minus['match_stats'] = df_minus['peak_matches'].apply(lambda x: get_match_stats(x))\n", + "df_minus['num_unique_peaks_matched'] = df_minus.apply(\n", + " lambda x: x['match_stats'][0], axis=1\n", ")\n", - "df_minus[\"num_conflicts_in_peak_matching\"] = df_minus.apply(\n", - " lambda x: x[\"match_stats\"][1], axis=1\n", + "df_minus['num_conflicts_in_peak_matching'] = df_minus.apply(\n", + " lambda x: x['match_stats'][1], axis=1\n", ")\n", - "df_minus[\"match_mode_counts\"] = df_minus.apply(lambda x: x[\"match_stats\"][2], axis=1)\n", - "u = df_minus[\"num_unique_peaks_matched\"].sum()\n", - "s = df_minus[\"num_conflicts_in_peak_matching\"].sum()\n", + "df_minus['match_mode_counts'] = 
df_minus.apply(lambda x: x['match_stats'][2], axis=1)\n", + "u = df_minus['num_unique_peaks_matched'].sum()\n", + "s = df_minus['num_conflicts_in_peak_matching'].sum()\n", "print(\n", - " f\"Total number of uniquely matched peaks: {u} , conflicts found within {s} matches ({100 * s / (u + s):.02f} %))\"\n", + " f'Total number of uniquely matched peaks: {u} , conflicts found within {s} matches ({100 * s / (u + s):.02f} %))'\n", ")\n", - "print(f\"Total number of conflicting peak to fragment matches: {s}\")\n", + "print(f'Total number of conflicting peak to fragment matches: {s}')\n", "\n", "df_minus.shape" ] @@ -1084,34 +1073,34 @@ "source": [ "fig, axs = plt.subplots(1, 3, figsize=(12.8, 6.4), sharey=False)\n", "\n", - "fig.suptitle(f\"Identified peaks with fragment offset\")\n", + "fig.suptitle('Identified peaks with fragment offset')\n", "# plt.title(f\"Identified peaks with fragment offset: {str(off)}\")\n", "sns.histplot(\n", " ax=axs[0],\n", " data=df_minus,\n", - " x=\"num_peaks_matched\",\n", + " x='num_peaks_matched',\n", " color=color_palette[0],\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " bins=range(0, 20, 1),\n", ")\n", "# axs[0].set_ylim(-0.5, 10)\n", - "axs[0].set_ylabel(\"peaks identified\")\n", + "axs[0].set_ylabel('peaks identified')\n", "\n", "\n", "sns.boxplot(\n", - " ax=axs[1], data=df_minus, y=\"num_unique_peaks_matched\", color=color_palette[1]\n", + " ax=axs[1], data=df_minus, y='num_unique_peaks_matched', color=color_palette[1]\n", ")\n", "axs[1].set_ylim(-0.5, 15)\n", - "axs[1].set_xlabel(\"unique matches\")\n", - "axs[1].set_ylabel(\"\")\n", + "axs[1].set_xlabel('unique matches')\n", + "axs[1].set_ylabel('')\n", "\n", "\n", "sns.boxplot(\n", - " ax=axs[2], data=df_minus, y=\"num_conflicts_in_peak_matching\", color=color_palette[3]\n", + " ax=axs[2], data=df_minus, y='num_conflicts_in_peak_matching', color=color_palette[3]\n", ")\n", "axs[2].set_ylim(-0.5, 15)\n", - "axs[2].set_xlabel(\"conflicts\")\n", - "axs[2].set_ylabel(\"\")\n", + "axs[2].set_xlabel('conflicts')\n", + "axs[2].set_ylabel('')\n", "\n", "plt.show()" ] @@ -1132,10 +1121,10 @@ " mode_counts[mode] += m[mode]\n", "\n", "\n", - "df_minus[\"match_mode_counts\"].apply(update_mode_counts)\n", + "df_minus['match_mode_counts'].apply(update_mode_counts)\n", "\n", "mode_counts = dict(\n", - " (key.replace(\"]+\", \"]-\"), value) for (key, value) in mode_counts.items()\n", + " (key.replace(']+', ']-'), value) for (key, value) in mode_counts.items()\n", ")\n", "\n", "sns.barplot(\n", @@ -1143,14 +1132,14 @@ " x=list(mode_counts.keys()),\n", " y=[mode_counts[k] for k in mode_counts.keys()],\n", " palette=color_palette,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " linewidth=1.5,\n", ")\n", "axs[1].pie(\n", " [mode_counts[k] for k in mode_counts.keys()],\n", " labels=list(mode_counts.keys()),\n", " colors=color_palette,\n", - " wedgeprops={\"edgecolor\": \"black\", \"linewidth\": 1.5},\n", + " wedgeprops={'edgecolor': 'black', 'linewidth': 1.5},\n", ")\n", "\n", "plt.show()" @@ -1164,8 +1153,8 @@ "source": [ "for i in range(0, 6):\n", " print(\n", - " f\"Minimum {i} unique peaks identified (including precursors): \",\n", - " (df_minus[\"num_unique_peaks_matched\"] >= i).sum(),\n", + " f'Minimum {i} unique peaks identified (including precursors): ',\n", + " (df_minus['num_unique_peaks_matched'] >= i).sum(),\n", " )" ] }, @@ -1185,7 +1174,7 @@ "source": [ "df = pd.concat([df_nist, df_minus], axis=0)\n", "\n", - "df[\"Precursor_type\"].value_counts()" + 
"df['Precursor_type'].value_counts()" ] }, { @@ -1196,8 +1185,8 @@ "source": [ "for i in range(0, 6):\n", " print(\n", - " f\"Minimum {i} unique peaks identified (including precursors): \",\n", - " (df[\"num_unique_peaks_matched\"] >= i).sum(),\n", + " f'Minimum {i} unique peaks identified (including precursors): ',\n", + " (df['num_unique_peaks_matched'] >= i).sum(),\n", " )" ] }, @@ -1208,61 +1197,61 @@ "outputs": [], "source": [ "save_df = False\n", - "name = \"nist_msms_filtered\"\n", - "date = \"XXX\" # \"07_2023\"\n", + "name = 'nist_msms_filtered'\n", + "date = 'XXX' # \"07_2023\"\n", "min_peaks = 1\n", "\n", "if save_df:\n", " key_columns = [\n", - " \"Name\",\n", - " \"Synon\",\n", - " \"Notes\",\n", - " \"Precursor_type\",\n", - " \"Spectrum_type\",\n", - " \"PrecursorMZ\",\n", - " \"Instrument_type\",\n", - " \"Instrument\",\n", - " \"Sample_inlet\",\n", - " \"Ionization\",\n", - " \"Collision_energy\",\n", - " \"Ion_mode\",\n", - " \"Special_fragmentation\",\n", - " \"InChIKey\",\n", - " \"Formula\",\n", - " \"MW\",\n", - " \"ExactMass\",\n", - " \"CASNO\",\n", - " \"NISTNO\",\n", - " \"ID\",\n", - " \"Comment\",\n", - " \"Num peaks\",\n", - " \"peaks\",\n", - " \"Link\",\n", - " \"Related_CAS#\",\n", - " \"Collision_gas\",\n", - " \"Pressure\",\n", - " \"In-source_voltage\",\n", - " \"msN_pathway\",\n", - " \"MOL\",\n", - " \"SMILES\",\n", - " \"InChI\",\n", - " \"K\",\n", - " \"ExactMolWeight\",\n", - " \"matching_key\",\n", - " \"theoretical_precursor_mz\",\n", - " \"precursor_offset\",\n", - " \"CE\",\n", - " \"CE_type\",\n", - " \"peak_matches\",\n", - " \"num_peaks_matched\",\n", - " \"match_stats\",\n", - " \"num_unique_peaks_matched\",\n", - " \"num_conflicts_in_peak_matching\",\n", - " \"match_mode_counts\",\n", + " 'Name',\n", + " 'Synon',\n", + " 'Notes',\n", + " 'Precursor_type',\n", + " 'Spectrum_type',\n", + " 'PrecursorMZ',\n", + " 'Instrument_type',\n", + " 'Instrument',\n", + " 'Sample_inlet',\n", + " 'Ionization',\n", + " 'Collision_energy',\n", + " 'Ion_mode',\n", + " 'Special_fragmentation',\n", + " 'InChIKey',\n", + " 'Formula',\n", + " 'MW',\n", + " 'ExactMass',\n", + " 'CASNO',\n", + " 'NISTNO',\n", + " 'ID',\n", + " 'Comment',\n", + " 'Num peaks',\n", + " 'peaks',\n", + " 'Link',\n", + " 'Related_CAS#',\n", + " 'Collision_gas',\n", + " 'Pressure',\n", + " 'In-source_voltage',\n", + " 'msN_pathway',\n", + " 'MOL',\n", + " 'SMILES',\n", + " 'InChI',\n", + " 'K',\n", + " 'ExactMolWeight',\n", + " 'matching_key',\n", + " 'theoretical_precursor_mz',\n", + " 'precursor_offset',\n", + " 'CE',\n", + " 'CE_type',\n", + " 'peak_matches',\n", + " 'num_peaks_matched',\n", + " 'match_stats',\n", + " 'num_unique_peaks_matched',\n", + " 'num_conflicts_in_peak_matching',\n", + " 'match_mode_counts',\n", " ]\n", - " file = library_directory + name + \"_min\" + str(min_peaks) + \"_\" + date + \".csv\"\n", - " print(\"saving to \", file)\n", - " df[df[\"num_unique_peaks_matched\"] >= min_peaks][key_columns].to_csv(file)\n", + " file = library_directory + name + '_min' + str(min_peaks) + '_' + date + '.csv'\n", + " print('saving to ', file)\n", + " df[df['num_unique_peaks_matched'] >= min_peaks][key_columns].to_csv(file)\n", "\n", " # df[key_columns].to_csv(library_directory + name + \"all\" + \"_\" + date + \".csv\")" ] diff --git a/notebooks/break_tendency.ipynb b/notebooks/break_tendency.ipynb index 2bf97c2..6468bb0 100644 --- a/notebooks/break_tendency.ipynb +++ b/notebooks/break_tendency.ipynb @@ -24,6 +24,7 @@ ], "source": [ "import sys\n", + "\n", "import torch\n", 
"\n", "seed = 42\n", @@ -32,32 +33,31 @@ "torch.set_printoptions(precision=2, sci_mode=False)\n", "\n", "\n", - "import pandas as pd\n", - "import numpy as np\n", "import ast\n", "\n", "# Plotting\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", "import seaborn as sns\n", "\n", - "\n", "# Load Modules\n", - "sys.path.append(\"..\")\n", + "sys.path.append('..')\n", "from os.path import expanduser\n", "\n", - "home = expanduser(\"~\")\n", - "from fiora.MOL.constants import DEFAULT_PPM, PPM, DEFAULT_MODES\n", - "from fiora.IO.LibraryLoader import LibraryLoader\n", - "from fiora.MOL.FragmentationTree import FragmentationTree\n", - "import fiora.visualization.spectrum_visualizer as sv\n", - "\n", - "from sklearn.metrics import r2_score\n", + "home = expanduser('~')\n", "import scipy\n", "from rdkit import RDLogger\n", + "from sklearn.metrics import r2_score\n", "\n", - "RDLogger.DisableLog(\"rdApp.*\")\n", + "import fiora.visualization.spectrum_visualizer as sv\n", + "from fiora.IO.LibraryLoader import LibraryLoader\n", + "from fiora.MOL.constants import DEFAULT_MODES, DEFAULT_PPM, PPM\n", + "from fiora.MOL.FragmentationTree import FragmentationTree\n", "\n", - "print(f\"Working with Python {sys.version}\")" + "RDLogger.DisableLog('rdApp.*')\n", + "\n", + "print(f'Working with Python {sys.version}')" ] }, { @@ -84,13 +84,13 @@ "source": [ "from typing import Literal\n", "\n", - "lib: Literal[\"NIST\", \"MSDIAL\", \"NIST/MSDIAL\"] = \"NIST/MSDIAL\"\n", - "print(f\"Preparing {lib} library\")\n", + "lib: Literal['NIST', 'MSDIAL', 'NIST/MSDIAL'] = 'NIST/MSDIAL'\n", + "print(f'Preparing {lib} library')\n", "\n", "test_run = False # Default: False\n", "if test_run:\n", " print(\n", - " \"+++ This is a test run with a small subset of data points. Results are not representative. +++\"\n", + " '+++ This is a test run with a small subset of data points. Results are not representative. 
+++'\n", " )" ] }, @@ -102,14 +102,14 @@ "source": [ "# key map to read metadata from pandas DataFrame\n", "metadata_key_map = {\n", - " \"name\": \"Name\",\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"ionization\": \"Ionization\",\n", - " \"precursor_mz\": \"PrecursorMZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", - " \"ccs\": \"CCS\",\n", + " 'name': 'Name',\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'ionization': 'Ionization',\n", + " 'precursor_mz': 'PrecursorMZ',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'RETENTIONTIME',\n", + " 'ccs': 'CCS',\n", "}\n", "\n", "\n", @@ -119,58 +119,58 @@ "\n", "\n", "def load_nist():\n", - " library_name = \"nist_msms_filteredall_07_2023\"\n", - " library_directory = f\"{home}/data/metabolites/NIST17/msp/nist_msms/\"\n", + " library_name = 'nist_msms_filteredall_07_2023'\n", + " library_directory = f'{home}/data/metabolites/NIST17/msp/nist_msms/'\n", " L = LibraryLoader()\n", - " df = L.load_from_csv(library_directory + library_name + \".csv\")\n", - " df[\"RETENTIONTIME\"] = np.nan\n", - " df[\"CCS\"] = np.nan\n", - " df[\"PPM_num\"] = 50\n", - " df[\"ppm_peak_tolerance\"] = df[\"PPM_num\"] * PPM\n", - " df[\"lib\"] = \"NIST\"\n", - " df[\"origin\"] = \"NIST\"\n", + " df = L.load_from_csv(library_directory + library_name + '.csv')\n", + " df['RETENTIONTIME'] = np.nan\n", + " df['CCS'] = np.nan\n", + " df['PPM_num'] = 50\n", + " df['ppm_peak_tolerance'] = df['PPM_num'] * PPM\n", + " df['lib'] = 'NIST'\n", + " df['origin'] = 'NIST'\n", "\n", " return df\n", "\n", "\n", "def load_msdial():\n", - " library_name = \"ms_dial_filtered_all_mid_08_2023\"\n", - " library_directory = f\"{home}/data/metabolites/MS_DIAL/\"\n", + " library_name = 'ms_dial_filtered_all_mid_08_2023'\n", + " library_directory = f'{home}/data/metabolites/MS_DIAL/'\n", " L = LibraryLoader()\n", - " df = L.load_from_csv(library_directory + library_name + \".csv\")\n", + " df = L.load_from_csv(library_directory + library_name + '.csv')\n", "\n", - " orbitrap_nametags = [\"Orbitrap\"]\n", - " qtof_nametags = [\"QTOF\", \"LC-ESI-QTOF\", \"ESI-QTOF\"]\n", - " df[\"Instrument_type\"] = df[\"INSTRUMENTTYPE\"].apply(\n", + " orbitrap_nametags = ['Orbitrap']\n", + " qtof_nametags = ['QTOF', 'LC-ESI-QTOF', 'ESI-QTOF']\n", + " df['Instrument_type'] = df['INSTRUMENTTYPE'].apply(\n", " lambda x: (\n", - " \"HCD\" if x in orbitrap_nametags else \"Q-TOF\" if x in qtof_nametags else x\n", + " 'HCD' if x in orbitrap_nametags else 'Q-TOF' if x in qtof_nametags else x\n", " )\n", " )\n", - " df[\"Ionization\"] = \"ESI\"\n", - " df[\"original_RT\"] = df[\"RETENTIONTIME\"].astype(float)\n", - " df[\"RETENTIONTIME\"] = df[\"RETENTIONTIME\"].astype(float)\n", - " df[\"CCS\"] = df[\"CCS\"].astype(float)\n", - " df[\"PrecursorMZ\"] = df[\"PRECURSORMZ\"].astype(float)\n", - " df[\"Precursor_type\"] = df[\"PRECURSORTYPE\"]\n", - " df[\"Name\"] = df[\"NAME\"]\n", - " df[\"PPM_num\"] = 50\n", - " df[\"ppm_peak_tolerance\"] = df[\"PPM_num\"] * PPM\n", - " df[\"lib\"] = \"MSDIAL\"\n", + " df['Ionization'] = 'ESI'\n", + " df['original_RT'] = df['RETENTIONTIME'].astype(float)\n", + " df['RETENTIONTIME'] = df['RETENTIONTIME'].astype(float)\n", + " df['CCS'] = df['CCS'].astype(float)\n", + " df['PrecursorMZ'] = df['PRECURSORMZ'].astype(float)\n", + " df['Precursor_type'] = df['PRECURSORTYPE']\n", + " df['Name'] = df['NAME']\n", + " df['PPM_num'] = 
50\n", + " df['ppm_peak_tolerance'] = df['PPM_num'] * PPM\n", + " df['lib'] = 'MSDIAL'\n", "\n", " # Filter out retention times using other phase types (e.g. HILIC) or of unknown/heterogeneous souces\n", " # df[df[\"origin\"] == \"MassBank High Quality Mass Spectral Database\"][\"RETENTIONTIME\"] = np.nan\n", " # df[df[\"origin\"] == \"Fiehn Lab HILIC Library\"][\"RETENTIONTIME\"] = np.nan\n", " bad_RT_libs = [\n", - " \"MassBank High Quality Mass Spectral Database\",\n", - " \"Fiehn Lab HILIC Library\",\n", + " 'MassBank High Quality Mass Spectral Database',\n", + " 'Fiehn Lab HILIC Library',\n", " ]\n", " potential_homogenous_RT_libs = [\n", - " \"BMDMS-NP\"\n", + " 'BMDMS-NP'\n", " ] # , 'RIKEN Plant Specialized Metabolome Annotation (PlaSMA) Authentic Standard Library' 'BMDMS-NP' , \"Global Natural Product Social Molecular Networking Library\"]\n", - " df[\"RETENTIONTIME\"] = df.apply(\n", + " df['RETENTIONTIME'] = df.apply(\n", " lambda x: (\n", - " x[\"RETENTIONTIME\"]\n", - " if x[\"origin\"] in potential_homogenous_RT_libs\n", + " x['RETENTIONTIME']\n", + " if x['origin'] in potential_homogenous_RT_libs\n", " else np.nan\n", " ),\n", " axis=1,\n", @@ -179,16 +179,16 @@ " return df\n", "\n", "\n", - "if lib == \"NIST\":\n", + "if lib == 'NIST':\n", " df = load_nist()\n", - "elif lib == \"MSDIAL\":\n", + "elif lib == 'MSDIAL':\n", " df = load_msdial()\n", - "elif lib == \"NIST/MSDIAL\":\n", + "elif lib == 'NIST/MSDIAL':\n", " df = pd.concat([load_nist(), load_msdial()], ignore_index=True)\n", " # df.reset_index(inplace=True) # Avoid conflict from index overlap of the two dataframes\n", "\n", "# Restore dictionary values\n", - "dict_columns = [\"peaks\"]\n", + "dict_columns = ['peaks']\n", "for col in dict_columns:\n", " df[col] = df[col].apply(ast.literal_eval)" ] @@ -221,12 +221,12 @@ ], "source": [ "# %%capture\n", - "from fiora.MOL.Metabolite import Metabolite\n", + "from fiora.GNN.SetupFeatureEncoder import SetupFeatureEncoder\n", + "\n", "from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", "from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder\n", - "from fiora.GNN.SetupFeatureEncoder import SetupFeatureEncoder\n", + "from fiora.MOL.Metabolite import Metabolite\n", "\n", - "#\n", "filter_spectra = True\n", "CE_upper_limit = 100.0\n", "weight_upper_limit = 1000.0\n", @@ -237,51 +237,51 @@ " # df = df.iloc[5000:20000,:]\n", "\n", "\n", - "df[\"Metabolite\"] = df[\"SMILES\"].apply(Metabolite)\n", - "df[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df['Metabolite'] = df['SMILES'].apply(Metabolite)\n", + "df['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", - "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", + "node_encoder = AtomFeatureEncoder(feature_list=['symbol', 'num_hydrogen', 'ring_type'])\n", + "bond_encoder = BondFeatureEncoder(feature_list=['bond_type', 'ring_type'])\n", "setup_encoder = SetupFeatureEncoder(\n", " feature_list=[\n", - " \"collision_energy\",\n", - " \"molecular_weight\",\n", - " \"precursor_mode\",\n", - " \"instrument\",\n", + " 'collision_energy',\n", + " 'molecular_weight',\n", + " 'precursor_mode',\n", + " 'instrument',\n", " ]\n", ")\n", "rt_encoder = SetupFeatureEncoder(\n", - " feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"]\n", + " feature_list=['molecular_weight', 'precursor_mode', 'instrument']\n", ")\n", "\n", "if 
filter_spectra:\n", - " setup_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", - " setup_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", - " rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", - "df[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", + " setup_encoder.normalize_features['collision_energy']['max'] = CE_upper_limit\n", + " setup_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", + " rt_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", + "df['Metabolite'].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", "\n", - "df[\"summary\"] = df.apply(\n", + "df['summary'] = df.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n", ")\n", "df.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", ")\n", "\n", "if filter_spectra:\n", " num_ori = df.shape[0]\n", - " correct_energy = df[\"Metabolite\"].apply(\n", + " correct_energy = df['Metabolite'].apply(\n", " lambda x: (\n", - " x.metadata[\"collision_energy\"] <= CE_upper_limit\n", - " and x.metadata[\"collision_energy\"] > 1\n", + " x.metadata['collision_energy'] <= CE_upper_limit\n", + " and x.metadata['collision_energy'] > 1\n", " )\n", " )\n", " df = df[correct_energy]\n", - " correct_weight = df[\"Metabolite\"].apply(\n", - " lambda x: x.metadata[\"molecular_weight\"] <= weight_upper_limit\n", + " correct_weight = df['Metabolite'].apply(\n", + " lambda x: x.metadata['molecular_weight'] <= weight_upper_limit\n", " )\n", " df = df[correct_weight]\n", - " print(f\"Filtering spectra ({num_ori}) down to {df.shape[0]}\")\n", + " print(f'Filtering spectra ({num_ori}) down to {df.shape[0]}')\n", " # print(df[\"Precursor_type\"].value_counts())" ] }, @@ -293,21 +293,21 @@ "source": [ "import pandas as pd\n", "\n", - "casmi16_path = f\"{home}/data/metabolites/CASMI_2016/casmi16_challenges_combined.csv\"\n", - "casmi16train_path = f\"{home}/data/metabolites/CASMI_2016/casmi16_training_combined.csv\"\n", - "casmi22_path = f\"{home}/data/metabolites/CASMI_2022/casmi22_challenges_combined.csv\"\n", + "casmi16_path = f'{home}/data/metabolites/CASMI_2016/casmi16_challenges_combined.csv'\n", + "casmi16train_path = f'{home}/data/metabolites/CASMI_2016/casmi16_training_combined.csv'\n", + "casmi22_path = f'{home}/data/metabolites/CASMI_2022/casmi22_challenges_combined.csv'\n", "\n", "df_cas = pd.read_csv(casmi16_path, index_col=[0], low_memory=False)\n", "df_cast = pd.read_csv(casmi16train_path, index_col=[0], low_memory=False)\n", "df_cas22 = pd.read_csv(casmi22_path, index_col=[0], low_memory=False)\n", "\n", "# Restore dictionary values\n", - "dict_columns = [\"peaks\", \"Candidates\"]\n", + "dict_columns = ['peaks', 'Candidates']\n", "for col in dict_columns:\n", " df_cas[col] = df_cas[col].apply(ast.literal_eval)\n", " df_cast[col] = df_cast[col].apply(ast.literal_eval)\n", "\n", - "df_cas22[\"peaks\"] = df_cas22[\"peaks\"].apply(ast.literal_eval)\n", + "df_cas22['peaks'] = df_cas22['peaks'].apply(ast.literal_eval)\n", "df_cas22 = df_cas22.reset_index()" ] }, @@ -411,36 +411,36 @@ } ], "source": [ - "df_cas[\"Metabolite\"] = df_cas[\"SMILES\"].apply(Metabolite)\n", - "df_cas[\"CCS\"] = [[]] * df_cas.shape[0]\n", + "df_cas['Metabolite'] = 
df_cas['SMILES'].apply(Metabolite)\n", + "df_cas['CCS'] = [[]] * df_cas.shape[0]\n", "### CHECK IF train test split is correct\n", "\n", "iii = []\n", "xxx = []\n", "for i, d in df_cas.iterrows():\n", - " m = d[\"Metabolite\"]\n", + " m = d['Metabolite']\n", "\n", " for x, D in df.iterrows():\n", - " M = D[\"Metabolite\"]\n", + " M = D['Metabolite']\n", " if m == M:\n", " iii += [i]\n", " xxx += [x]\n", - " if D[\"CCS\"]:\n", - " df_cas.at[i, \"CCS\"] = df_cas.at[i, \"CCS\"] + [\n", - " D[\"CCS\"]\n", + " if D['CCS']:\n", + " df_cas.at[i, 'CCS'] = df_cas.at[i, 'CCS'] + [\n", + " D['CCS']\n", " ] # Add CCS metadata\n", "\n", "\n", "iii = np.unique(iii)\n", "print(\n", - " f\"Found {len(iii)} instances violating test/train split (CASMI 16 Challenge). Metabolite found in train/val set.\"\n", + " f'Found {len(iii)} instances violating test/train split (CASMI 16 Challenge). Metabolite found in train/val set.'\n", ")\n", - "print(f\"Dropping {len(xxx)} spectra from training DataFrame.\")\n", + "print(f'Dropping {len(xxx)} spectra from training DataFrame.')\n", "df.drop(xxx, inplace=True)\n", "\n", "# Add CCS metadata\n", - "df_cas[\"CCS_std\"] = df_cas[\"CCS\"].apply(np.std)\n", - "df_cas[\"CCS\"] = df_cas[\"CCS\"].apply(np.mean)" + "df_cas['CCS_std'] = df_cas['CCS'].apply(np.std)\n", + "df_cas['CCS'] = df_cas['CCS'].apply(np.mean)" ] }, { @@ -474,36 +474,36 @@ } ], "source": [ - "df_cast[\"Metabolite\"] = df_cast[\"SMILES\"].apply(Metabolite)\n", - "df_cast[\"CCS\"] = [[]] * df_cast.shape[0]\n", + "df_cast['Metabolite'] = df_cast['SMILES'].apply(Metabolite)\n", + "df_cast['CCS'] = [[]] * df_cast.shape[0]\n", "### CHECK IF train test split is correct\n", "\n", "iii = []\n", "xxx = []\n", "for i, d in df_cast.iterrows():\n", - " m = d[\"Metabolite\"]\n", + " m = d['Metabolite']\n", "\n", " for x, D in df.iterrows():\n", - " M = D[\"Metabolite\"]\n", + " M = D['Metabolite']\n", " if m == M:\n", " iii += [i]\n", " xxx += [x]\n", - " if D[\"CCS\"]:\n", - " df_cast.at[i, \"CCS\"] = df_cast.at[i, \"CCS\"] + [\n", - " D[\"CCS\"]\n", + " if D['CCS']:\n", + " df_cast.at[i, 'CCS'] = df_cast.at[i, 'CCS'] + [\n", + " D['CCS']\n", " ] # Add CCS metadata\n", "\n", "\n", "iii = np.unique(iii)\n", "print(\n", - " f\"Found {len(iii)} instances violating test/train split (CASMI 16 Training). Metabolite found in train/val set.\"\n", + " f'Found {len(iii)} instances violating test/train split (CASMI 16 Training). 
Metabolite found in train/val set.'\n", ")\n", - "print(f\"Dropping {len(xxx)} spectra from training DataFrame.\")\n", + "print(f'Dropping {len(xxx)} spectra from training DataFrame.')\n", "df.drop(xxx, inplace=True)\n", "\n", "# Add CCS metadata\n", - "df_cast[\"CCS_std\"] = df_cast[\"CCS\"].apply(np.std)\n", - "df_cast[\"CCS\"] = df_cast[\"CCS\"].apply(np.mean)" + "df_cast['CCS_std'] = df_cast['CCS'].apply(np.std)\n", + "df_cast['CCS'] = df_cast['CCS'].apply(np.mean)" ] }, { @@ -535,9 +535,9 @@ } ], "source": [ - "df_cas22_unique = df_cas22.drop_duplicates(subset=\"ChallengeName\", keep=\"first\")\n", + "df_cas22_unique = df_cas22.drop_duplicates(subset='ChallengeName', keep='first')\n", "df_cas22_unique.reset_index(inplace=True)\n", - "df_cas22_unique[\"Metabolite\"] = df_cas22_unique[\"SMILES\"].apply(Metabolite)\n", + "df_cas22_unique['Metabolite'] = df_cas22_unique['SMILES'].apply(Metabolite)\n", "df_cas22_unique.shape" ] }, @@ -601,38 +601,38 @@ "iii = []\n", "xxx = []\n", "\n", - "df_cas22_unique[\"CCS\"] = [[]] * df_cas22_unique.shape[0]\n", + "df_cas22_unique['CCS'] = [[]] * df_cas22_unique.shape[0]\n", "\n", "\n", "for i, d in df_cas22_unique.iterrows():\n", - " m = d[\"Metabolite\"]\n", + " m = d['Metabolite']\n", "\n", " for x, D in df.iterrows():\n", - " M = D[\"Metabolite\"]\n", + " M = D['Metabolite']\n", " if m == M:\n", " iii += [i]\n", " xxx += [x]\n", "\n", - " if D[\"CCS\"]:\n", - " df_cas22_unique.at[i, \"CCS\"] = df_cas22_unique.at[i, \"CCS\"] + [\n", - " D[\"CCS\"]\n", + " if D['CCS']:\n", + " df_cas22_unique.at[i, 'CCS'] = df_cas22_unique.at[i, 'CCS'] + [\n", + " D['CCS']\n", " ] # Add CCS metadata\n", "\n", "iii = np.unique(iii)\n", "print(\n", - " f\"Found {len(iii)} instances violating test/train split (CASMI 22). Metabolite found in train/val set.\"\n", + " f'Found {len(iii)} instances violating test/train split (CASMI 22). 
Metabolite found in train/val set.'\n", ")\n", - "print(f\"Dropping {len(xxx)} spectra from training DataFrame.\")\n", + "print(f'Dropping {len(xxx)} spectra from training DataFrame.')\n", "df.drop(xxx, inplace=True)\n", "\n", "# Add CCS metadata\n", - "df_cas22_unique[\"CCS_std\"] = df_cas22_unique[\"CCS\"].apply(np.std)\n", - "df_cas22_unique[\"CCS\"] = df_cas22_unique[\"CCS\"].apply(np.mean)\n", - "df_cas22[\"CCS\"] = df_cas22[\"ChallengeName\"].apply(\n", - " lambda x: df_cas22_unique[df_cas22_unique[\"ChallengeName\"] == x][\"CCS\"].iloc[0]\n", + "df_cas22_unique['CCS_std'] = df_cas22_unique['CCS'].apply(np.std)\n", + "df_cas22_unique['CCS'] = df_cas22_unique['CCS'].apply(np.mean)\n", + "df_cas22['CCS'] = df_cas22['ChallengeName'].apply(\n", + " lambda x: df_cas22_unique[df_cas22_unique['ChallengeName'] == x]['CCS'].iloc[0]\n", ")\n", - "df_cas22[\"CCS_std\"] = df_cas22[\"ChallengeName\"].apply(\n", - " lambda x: df_cas22_unique[df_cas22_unique[\"ChallengeName\"] == x][\"CCS_std\"].iloc[0]\n", + "df_cas22['CCS_std'] = df_cas22['ChallengeName'].apply(\n", + " lambda x: df_cas22_unique[df_cas22_unique['ChallengeName'] == x]['CCS_std'].iloc[0]\n", ")" ] }, @@ -642,15 +642,15 @@ "metadata": {}, "outputs": [], "source": [ - "df_cas22_unique[~df_cas22_unique[\"CCS\"].isna()]\n", + "df_cas22_unique[~df_cas22_unique['CCS'].isna()]\n", "\n", "# TODO add those to df_cas22\n", "\n", - "df_cas22[\"CCS\"] = df_cas22[\"ChallengeName\"].apply(\n", - " lambda x: df_cas22_unique[df_cas22_unique[\"ChallengeName\"] == x][\"CCS\"].iloc[0]\n", + "df_cas22['CCS'] = df_cas22['ChallengeName'].apply(\n", + " lambda x: df_cas22_unique[df_cas22_unique['ChallengeName'] == x]['CCS'].iloc[0]\n", ")\n", - "df_cas22[\"CCS_std\"] = df_cas22[\"ChallengeName\"].apply(\n", - " lambda x: df_cas22_unique[df_cas22_unique[\"ChallengeName\"] == x][\"CCS_std\"].iloc[0]\n", + "df_cas22['CCS_std'] = df_cas22['ChallengeName'].apply(\n", + " lambda x: df_cas22_unique[df_cas22_unique['ChallengeName'] == x]['CCS_std'].iloc[0]\n", ")" ] }, @@ -663,9 +663,9 @@ "# Save casmi with\n", "save_df = False\n", "if save_df:\n", - " df_cas.to_csv(f\"{home}/data/metabolites/CASMI_2016/casmi16_withCCS.csv\")\n", - " df_cast.to_csv(f\"{home}/data/metabolites/CASMI_2016/casmi16t_withCCS.csv\")\n", - " df_cas22.to_csv(f\"{home}/data/metabolites/CASMI_2022/casmi22_withCCS.csv\")\n", + " df_cas.to_csv(f'{home}/data/metabolites/CASMI_2016/casmi16_withCCS.csv')\n", + " df_cast.to_csv(f'{home}/data/metabolites/CASMI_2016/casmi16t_withCCS.csv')\n", + " df_cas22.to_csv(f'{home}/data/metabolites/CASMI_2022/casmi22_withCCS.csv')\n", "\n", " print(df_cas.head(3))\n", " print(df_cas22.head(3))" @@ -686,11 +686,11 @@ } ], "source": [ - "print(\"Assigning unique metabolite identifiers.\")\n", + "print('Assigning unique metabolite identifiers.')\n", "\n", "metabolite_id_map = {}\n", "\n", - "for metabolite in df[\"Metabolite\"]:\n", + "for metabolite in df['Metabolite']:\n", " is_new = True\n", " for id, other in metabolite_id_map.items():\n", " if metabolite == other:\n", @@ -702,12 +702,12 @@ " metabolite.id = new_id\n", " metabolite_id_map[new_id] = metabolite\n", "\n", - "df[\"group_id\"] = df[\"Metabolite\"].apply(lambda x: x.get_id())\n", - "df[\"num_per_group\"] = df[\"group_id\"].map(df[\"group_id\"].value_counts())\n", + "df['group_id'] = df['Metabolite'].apply(lambda x: x.get_id())\n", + "df['num_per_group'] = df['group_id'].map(df['group_id'].value_counts())\n", "\n", "for i, data in df.iterrows():\n", - " data[\"Metabolite\"].set_loss_weight(1.0 / 
data[\"num_per_group\"])\n", - "print(f\"Found {len(metabolite_id_map)} unique molecular structures.\")" + " data['Metabolite'].set_loss_weight(1.0 / data['num_per_group'])\n", + "print(f'Found {len(metabolite_id_map)} unique molecular structures.')" ] }, { @@ -730,7 +730,7 @@ } ], "source": [ - "df.groupby(\"lib\").group_id.unique().apply(len)" + "df.groupby('lib').group_id.unique().apply(len)" ] }, { @@ -759,7 +759,7 @@ "metadata": {}, "outputs": [], "source": [ - "df[\"loss_weight\"] = df[\"Metabolite\"].apply(lambda x: x.loss_weight)" + "df['loss_weight'] = df['Metabolite'].apply(lambda x: x.loss_weight)" ] }, { @@ -817,32 +817,31 @@ ], "source": [ "from fiora.MOL.mol_graph import draw_graph\n", + "from fiora.visualization.define_colors import *\n", "from fiora.visualization.define_colors import (\n", - " define_figure_style,\n", - " color_palette,\n", " bluepink,\n", " bluepink_grad,\n", " bluepink_grad8,\n", + " color_palette,\n", + " define_figure_style,\n", " tri_palette,\n", ")\n", - "from fiora.visualization.define_colors import *\n", - "import matplotlib.pyplot as plt\n", "\n", - "matplotlib.rcParams[\"figure.figsize\"] = (12, 6)\n", + "matplotlib.rcParams['figure.figsize'] = (12, 6)\n", "\n", - "magma_palette = define_figure_style(style=\"magma-white\", palette_steps=8)\n", + "magma_palette = define_figure_style(style='magma-white', palette_steps=8)\n", "\n", "if not test_run:\n", - " EXAMPLE_ID = 32271 if (lib == \"NIST\") else 7607 if lib == \"MSDIAL\" else 0\n", + " EXAMPLE_ID = 32271 if (lib == 'NIST') else 7607 if lib == 'MSDIAL' else 0\n", " example = df.loc[EXAMPLE_ID]\n", "\n", " fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 1]}, sharey=False\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 1]}, sharey=False\n", " )\n", " set_light_theme()\n", "\n", - " img = example[\"Metabolite\"].draw(ax=axs[0])\n", - " draw_graph(example[\"Metabolite\"].Graph, ax=axs[1])" + " img = example['Metabolite'].draw(ax=axs[0])\n", + " draw_graph(example['Metabolite'].Graph, ax=axs[1])" ] }, { @@ -864,18 +863,18 @@ "source": [ "if not test_run:\n", " fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", " )\n", "\n", - " img = example[\"Metabolite\"].draw(ax=axs[0])\n", + " img = example['Metabolite'].draw(ax=axs[0])\n", "\n", " axs[0].grid(False)\n", " axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", " )\n", - " axs[0].set_title(str(example[\"Metabolite\"]))\n", + " axs[0].set_title(str(example['Metabolite']))\n", " axs[0].imshow(img)\n", - " axs[0].axis(\"off\")\n", + " axs[0].axis('off')\n", " sv.plot_spectrum(example, ax=axs[1])" ] }, @@ -886,10 +885,10 @@ "outputs": [], "source": [ "%%capture\n", - "df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "df.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=x['ppm_peak_tolerance']\n", " ),\n", " axis=1,\n", ")" @@ -927,7 +926,7 @@ ], "source": [ "example = df.iloc[10]\n", - 
"example[\"Metabolite\"].match_stats\n", + "example['Metabolite'].match_stats\n", "# raise KeyboardInterrupt()" ] }, @@ -948,13 +947,13 @@ } ], "source": [ - "df[\"np\"] = df[\"Metabolite\"].apply(lambda x: x.match_stats[\"num_peak_matches\"])\n", - "df[\"npf\"] = df[\"Metabolite\"].apply(lambda x: x.match_stats[\"num_peak_matches_filtered\"])\n", - "df[\"ppf\"] = df[\"Metabolite\"].apply(\n", - " lambda x: x.match_stats[\"percent_peak_matches_filtered\"]\n", + "df['np'] = df['Metabolite'].apply(lambda x: x.match_stats['num_peak_matches'])\n", + "df['npf'] = df['Metabolite'].apply(lambda x: x.match_stats['num_peak_matches_filtered'])\n", + "df['ppf'] = df['Metabolite'].apply(\n", + " lambda x: x.match_stats['percent_peak_matches_filtered']\n", ")\n", "\n", - "sns.violinplot(data=df, y=\"ppf\")\n", + "sns.violinplot(data=df, y='ppf')\n", "plt.show()" ] }, @@ -975,7 +974,7 @@ } ], "source": [ - "sum((df[\"ppf\"] > 0.5) & (df[\"npf\"] < 5) & (df[\"npf\"] > 2))" + "sum((df['ppf'] > 0.5) & (df['npf'] < 5) & (df['npf'] > 2))" ] }, { @@ -1006,7 +1005,7 @@ } ], "source": [ - "df[\"np\"].value_counts().head(10)" + "df['np'].value_counts().head(10)" ] }, { @@ -1037,7 +1036,7 @@ } ], "source": [ - "df[\"npf\"].value_counts().head(10)" + "df['npf'].value_counts().head(10)" ] }, { @@ -1093,21 +1092,21 @@ } ], "source": [ - "from fiora.MOL.constants import ADDUCT_WEIGHTS, PPM\n", + "from fiora.MOL.constants import ADDUCT_WEIGHTS\n", "\n", - "df[\"Precursor_offset\"] = df[\"PrecursorMZ\"] - df.apply(\n", - " lambda x: x[\"Metabolite\"].ExactMolWeight + ADDUCT_WEIGHTS[x[\"Precursor_type\"]],\n", + "df['Precursor_offset'] = df['PrecursorMZ'] - df.apply(\n", + " lambda x: x['Metabolite'].ExactMolWeight + ADDUCT_WEIGHTS[x['Precursor_type']],\n", " axis=1,\n", ")\n", - "df[\"Precursor_abs_error\"] = abs(df[\"Precursor_offset\"])\n", - "df[\"Precursor_rel_error\"] = df[\"Precursor_abs_error\"] / df[\"PrecursorMZ\"]\n", - "df[\"Precursor_ppm_error\"] = df[\"Precursor_abs_error\"] / (df[\"PrecursorMZ\"] * PPM)\n", + "df['Precursor_abs_error'] = abs(df['Precursor_offset'])\n", + "df['Precursor_rel_error'] = df['Precursor_abs_error'] / df['PrecursorMZ']\n", + "df['Precursor_ppm_error'] = df['Precursor_abs_error'] / (df['PrecursorMZ'] * PPM)\n", "print(\n", - " (df[\"Precursor_ppm_error\"] > df[\"PPM_num\"]).sum(),\n", - " \"found with misaligned precursor. Removing these.\",\n", + " (df['Precursor_ppm_error'] > df['PPM_num']).sum(),\n", + " 'found with misaligned precursor. 
Removing these.',\n", ")\n", "\n", - "df = df[df[\"Precursor_ppm_error\"] <= df[\"PPM_num\"]]" + "df = df[df['Precursor_ppm_error'] <= df['PPM_num']]" ] }, { @@ -1137,8 +1136,8 @@ } ], "source": [ - "df[\"coverage\"] = df[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", - "sns.kdeplot(data=df, x=\"coverage\", hue=\"lib\") # , multiple=\"dodge\")" + "df['coverage'] = df['Metabolite'].apply(lambda x: x.match_stats['coverage'])\n", + "sns.kdeplot(data=df, x='coverage', hue='lib') # , multiple=\"dodge\")" ] }, { @@ -1168,33 +1167,33 @@ "# TODO Implement conflict solver\n", "\n", "coverage_tracker = {\n", - " \"counts\": [],\n", - " \"all\": [],\n", - " \"coverage\": [],\n", - " \"fragment_only_coverage\": [],\n", - " \"Precursor_type\": [],\n", + " 'counts': [],\n", + " 'all': [],\n", + " 'coverage': [],\n", + " 'fragment_only_coverage': [],\n", + " 'Precursor_type': [],\n", "}\n", "\n", "drop_index = []\n", "for i, d in df.iterrows():\n", - " M = d[\"Metabolite\"]\n", + " M = d['Metabolite']\n", "\n", - " coverage_tracker[\"counts\"] += [M.match_stats[\"counts\"]]\n", - " coverage_tracker[\"all\"] += [M.match_stats[\"ms_all_counts\"]]\n", - " coverage_tracker[\"fragment_only_coverage\"] += [M.match_stats[\"coverage_wo_prec\"]]\n", - " coverage_tracker[\"coverage\"] += [M.match_stats[\"coverage\"]]\n", - " coverage_tracker[\"Precursor_type\"] += [M.metadata[\"precursor_mode\"]]\n", + " coverage_tracker['counts'] += [M.match_stats['counts']]\n", + " coverage_tracker['all'] += [M.match_stats['ms_all_counts']]\n", + " coverage_tracker['fragment_only_coverage'] += [M.match_stats['coverage_wo_prec']]\n", + " coverage_tracker['coverage'] += [M.match_stats['coverage']]\n", + " coverage_tracker['Precursor_type'] += [M.metadata['precursor_mode']]\n", "\n", " #\n", " # IMPORTANT: FILTER AND CLEAN DATA\n", " #\n", "\n", " min_coverage = 0.5\n", - " if M.match_stats[\"coverage\"] < min_coverage: # Filter if total coverage is too low\n", + " if M.match_stats['coverage'] < min_coverage: # Filter if total coverage is too low\n", " drop_index.append(i)\n", "\n", " min_peaks = 2\n", - " if M.match_stats[\"num_peak_matches_filtered\"] < min_peaks:\n", + " if M.match_stats['num_peak_matches_filtered'] < min_peaks:\n", " drop_index.append(i)\n", "\n", " # Either condition is enough to keep the spectrum\n", @@ -1202,9 +1201,9 @@ " desired_peak_percentage = 0.5\n", " extremly_high_coverage = 0.8\n", " if (\n", - " (M.match_stats[\"num_peak_matches_filtered\"] < desired_peaks)\n", - " & (M.match_stats[\"percent_peak_matches_filtered\"] < desired_peak_percentage)\n", - " & (M.match_stats[\"coverage\"] < extremly_high_coverage)\n", + " (M.match_stats['num_peak_matches_filtered'] < desired_peaks)\n", + " & (M.match_stats['percent_peak_matches_filtered'] < desired_peak_percentage)\n", + " & (M.match_stats['coverage'] < extremly_high_coverage)\n", " ):\n", " drop_index.append(i)\n", "\n", @@ -1216,7 +1215,7 @@ "\n", " max_precursor = 0.9\n", " if (\n", - " M.match_stats[\"precursor_prob\"] > max_precursor\n", + " M.match_stats['precursor_prob'] > max_precursor\n", " ): # Filter if fragment coverage is too low (intensity wise)\n", " drop_index.append(i)\n", "\n", @@ -1227,17 +1226,17 @@ "\n", "# filter low res instruments TODO update to low quality spectra\n", "low_quality_tags = [\n", - " \"IT/ion trap\",\n", - " \"QqQ\",\n", - " \"LC-ESI-QQ\",\n", - " \"Flow-injection QqQ/MS\",\n", - " \"LC-APPI-QQ\",\n", - " \"LC-ESI-IT\",\n", - " \"LC-ESI-QIT\",\n", - " \"QIT\",\n", + " 'IT/ion trap',\n", + " 
'QqQ',\n", + " 'LC-ESI-QQ',\n", + " 'Flow-injection QqQ/MS',\n", + " 'LC-APPI-QQ',\n", + " 'LC-ESI-IT',\n", + " 'LC-ESI-QIT',\n", + " 'QIT',\n", "] # What about ESI-ITTOF? GC-APCI-QTOF?\n", - "low_res_machines = df[\"Metabolite\"].apply(\n", - " lambda x: x.metadata[\"instrument\"] in low_quality_tags\n", + "low_res_machines = df['Metabolite'].apply(\n", + " lambda x: x.metadata['instrument'] in low_quality_tags\n", ")\n", "drop_index += list(df[low_res_machines].index)\n", "\n", @@ -1245,18 +1244,18 @@ "\n", "plt.ylim([-0.02, 1.02])\n", "sns.boxplot(\n", - " ax=axs[0], data=coverage_tracker, y=\"fragment_only_coverage\", color=magma_palette[1]\n", + " ax=axs[0], data=coverage_tracker, y='fragment_only_coverage', color=magma_palette[1]\n", ")\n", - "sns.boxplot(ax=axs[1], data=coverage_tracker, y=\"coverage\", color=magma_palette[2])\n", - "sns.violinplot(ax=axs[2], data=coverage_tracker, y=\"coverage\", color=magma_palette[2])\n", - "axs[0].set_title(\"Coverage of peak intensity (fragments only)\")\n", - "axs[1].set_title(\"Coverage of peak intensity\")\n", - "axs[2].set_title(\"Coverage of peak intensity\")\n", - "axs[2].axhline(y=min_coverage, color=\"black\", linestyle=\"--\", label=\"Horizontal Line\")\n", + "sns.boxplot(ax=axs[1], data=coverage_tracker, y='coverage', color=magma_palette[2])\n", + "sns.violinplot(ax=axs[2], data=coverage_tracker, y='coverage', color=magma_palette[2])\n", + "axs[0].set_title('Coverage of peak intensity (fragments only)')\n", + "axs[1].set_title('Coverage of peak intensity')\n", + "axs[2].set_title('Coverage of peak intensity')\n", + "axs[2].axhline(y=min_coverage, color='black', linestyle='--', label='Horizontal Line')\n", "plt.show()\n", "\n", "print(\n", - " f\"Filtering out {len(drop_index)} that have only precursor matches || or || too little (intensity) coverage to make edge prediction possible\"\n", + " f'Filtering out {len(drop_index)} that have only precursor matches || or || too little (intensity) coverage to make edge prediction possible'\n", ")\n", "df.drop(drop_index, inplace=True)" ] @@ -1297,8 +1296,8 @@ "source": [ "def custom_sample(group):\n", " sample_size = 20\n", - " x_origin = group[group[\"lib\"] == \"NIST\"]\n", - " y_origin = group[group[\"lib\"] == \"MSDIAL\"]\n", + " x_origin = group[group['lib'] == 'NIST']\n", + " y_origin = group[group['lib'] == 'MSDIAL']\n", "\n", " if len(x_origin) >= sample_size:\n", " return x_origin.sample(n=sample_size, replace=False)\n", @@ -1315,13 +1314,13 @@ "\n", "# df = df.reset_index(drop=True)\n", "\n", - "num_structures = len(df[\"group_id\"].unique())\n", - "print(f\"Train and validate on {num_structures} unique structures\")\n", - "sns.histplot(df[\"group_id\"].value_counts(), stat=\"probability\", binwidth=1)\n", + "num_structures = len(df['group_id'].unique())\n", + "print(f'Train and validate on {num_structures} unique structures')\n", + "sns.histplot(df['group_id'].value_counts(), stat='probability', binwidth=1)\n", "plt.xlim([0, 25])\n", "plt.show()\n", "\n", - "sns.histplot(df, x=\"origin\", stat=\"count\")\n", + "sns.histplot(df, x='origin', stat='count')\n", "plt.xticks(rotation=90)\n", "plt.show()" ] @@ -1346,7 +1345,7 @@ } ], "source": [ - "df.groupby(\"lib\").group_id.unique().apply(len)" + "df.groupby('lib').group_id.unique().apply(len)" ] }, { @@ -1372,15 +1371,16 @@ ], "source": [ "import torch\n", - "from fiora.GNN.Trainer import Trainer\n", "import torch_geometric as geom\n", "\n", + "from fiora.GNN.Trainer import Trainer\n", + "\n", "if torch.cuda.is_available():\n", - " dev 
= \"cuda:0\"\n", + " dev = 'cuda:0'\n", "else:\n", - " dev = \"cpu\"\n", + " dev = 'cpu'\n", "\n", - "print(f\"Running on device: {dev}\")" + "print(f'Running on device: {dev}')" ] }, { @@ -1428,14 +1428,14 @@ } ], "source": [ - "df[\"pp\"] = df[\"Metabolite\"].apply(lambda x: x.match_stats[\"precursor_prob\"])\n", - "top_instrumenttypes = df[\"Instrument_type\"].value_counts().head(3).index\n", + "df['pp'] = df['Metabolite'].apply(lambda x: x.match_stats['precursor_prob'])\n", + "top_instrumenttypes = df['Instrument_type'].value_counts().head(3).index\n", "f, a = plt.subplots(1, 1, figsize=(8, 4))\n", "sns.histplot(\n", - " data=df[df[\"Instrument_type\"].isin(top_instrumenttypes)],\n", - " x=\"pp\",\n", - " hue=\"Instrument_type\",\n", - " multiple=\"dodge\",\n", + " data=df[df['Instrument_type'].isin(top_instrumenttypes)],\n", + " x='pp',\n", + " hue='Instrument_type',\n", + " multiple='dodge',\n", ")\n", "plt.show()" ] @@ -1454,8 +1454,8 @@ } ], "source": [ - "geo_data = df[\"Metabolite\"].apply(lambda x: x.as_geometric_data().to(dev)).values\n", - "print(f\"Prepared training/validation with {len(geo_data)} data points\")" + "geo_data = df['Metabolite'].apply(lambda x: x.as_geometric_data().to(dev)).values\n", + "print(f'Prepared training/validation with {len(geo_data)} data points')" ] }, { @@ -1510,27 +1510,27 @@ "outputs": [], "source": [ "model_params = {\n", - " \"gnn_type\": \"RGCNConv\",\n", - " \"depth\": 3,\n", - " \"hidden_dimension\": 300,\n", - " \"dense_layers\": 2,\n", - " \"embedding_aggregation\": \"concat\",\n", - " \"embedding_dimension\": 300,\n", - " \"input_dropout\": 0.2,\n", - " \"latent_dropout\": 0.1,\n", - " \"node_feature_layout\": node_encoder.feature_numbers,\n", - " \"edge_feature_layout\": bond_encoder.feature_numbers,\n", - " \"static_feature_dimension\": geo_data[0][\"static_edge_features\"].shape[1],\n", - " \"static_rt_feature_dimension\": geo_data[0][\"static_rt_features\"].shape[1],\n", - " \"output_dimension\": len(DEFAULT_MODES) * 2, # per edge\n", + " 'gnn_type': 'RGCNConv',\n", + " 'depth': 3,\n", + " 'hidden_dimension': 300,\n", + " 'dense_layers': 2,\n", + " 'embedding_aggregation': 'concat',\n", + " 'embedding_dimension': 300,\n", + " 'input_dropout': 0.2,\n", + " 'latent_dropout': 0.1,\n", + " 'node_feature_layout': node_encoder.feature_numbers,\n", + " 'edge_feature_layout': bond_encoder.feature_numbers,\n", + " 'static_feature_dimension': geo_data[0]['static_edge_features'].shape[1],\n", + " 'static_rt_feature_dimension': geo_data[0]['static_rt_features'].shape[1],\n", + " 'output_dimension': len(DEFAULT_MODES) * 2, # per edge\n", "}\n", "training_params = {\n", - " \"epochs\": 200 if not test_run else 20, # 180,\n", - " \"batch_size\": 256, # 128,\n", - " \"train_val_split\": 0.90,\n", - " \"learning_rate\": 0.0004, # 0.001,\n", - " \"with_RT\": False, # TODO CHANGED\n", - " \"with_CCS\": False,\n", + " 'epochs': 200 if not test_run else 20, # 180,\n", + " 'batch_size': 256, # 128,\n", + " 'train_val_split': 0.90,\n", + " 'learning_rate': 0.0004, # 0.001,\n", + " 'with_RT': False, # TODO CHANGED\n", + " 'with_CCS': False,\n", "}" ] }, @@ -1541,6 +1541,7 @@ "outputs": [], "source": [ "from fiora.GNN.GNNModules import GNNCompiler\n", + "\n", "from fiora.GNN.Losses import WeightedMSELoss, WeightedMSEMetric\n", "\n", "model = GNNCompiler(model_params).to(dev)" @@ -1589,11 +1590,11 @@ "\n", "# Make sure that the example is in the test split\n", "if not test_run:\n", - " ex_smiles = \"CC(NC(=O)CC1=CNC2=C1C=CC=C2)C(O)=O\"\n", + " ex_smiles = 
'CC(NC(=O)CC1=CNC2=C1C=CC=C2)C(O)=O'\n", " ex_metabolite = Metabolite(ex_smiles)\n", - " ex_compound_id = df[df[\"Metabolite\"] == ex_metabolite][\"group_id\"].iloc[0]\n", + " ex_compound_id = df[df['Metabolite'] == ex_metabolite]['group_id'].iloc[0]\n", "\n", - "group_ids = df[\"group_id\"].astype(int)\n", + "group_ids = df['group_id'].astype(int)\n", "keys = np.unique(group_ids)\n", "example_not_in_test_split = True\n", "\n", @@ -1601,18 +1602,18 @@ " train, val, test = train_val_test_split(keys, rseed=seed + i)\n", " if test_run or (ex_compound_id in test):\n", " print(\n", - " f\"Seed {seed + i} used to sample slits, such that the example Metabolite is in the test set.\"\n", + " f'Seed {seed + i} used to sample slits, such that the example Metabolite is in the test set.'\n", " )\n", " break\n", - "df[\"dataset\"] = df[\"group_id\"].apply(\n", + "df['dataset'] = df['group_id'].apply(\n", " lambda x: (\n", - " \"train\"\n", + " 'train'\n", " if x in train\n", - " else \"validation\"\n", + " else 'validation'\n", " if x in val\n", - " else \"test\"\n", + " else 'test'\n", " if x in test\n", - " else \"VALUE ERROR\"\n", + " else 'VALUE ERROR'\n", " )\n", ")" ] @@ -1631,27 +1632,27 @@ } ], "source": [ - "y_label = \"compiled_probsALL\"\n", + "y_label = 'compiled_probsALL'\n", "train_keys, val_keys = (\n", - " df[df[\"dataset\"] == \"train\"][\"group_id\"].unique(),\n", - " df[df[\"dataset\"] == \"validation\"][\"group_id\"].unique(),\n", + " df[df['dataset'] == 'train']['group_id'].unique(),\n", + " df[df['dataset'] == 'validation']['group_id'].unique(),\n", ")\n", "\n", "trainer = Trainer(\n", " geo_data,\n", " y_tag=y_label,\n", - " problem_type=\"regression\",\n", - " metric_dict={\"mse\": WeightedMSEMetric},\n", + " problem_type='regression',\n", + " metric_dict={'mse': WeightedMSEMetric},\n", " train_keys=train_keys,\n", " val_keys=val_keys,\n", " split_by_group=True,\n", " seed=seed,\n", " device=dev,\n", ")\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=training_params[\"learning_rate\"])\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=training_params['learning_rate'])\n", "# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)\n", "scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", - " optimizer, patience=8, factor=0.5, mode=\"min\", verbose=True\n", + " optimizer, patience=8, factor=0.5, mode='min', verbose=True\n", ")\n", "\n", "loss_fn = WeightedMSELoss()\n", @@ -1884,17 +1885,17 @@ } ], "source": [ - "tag = \"b1\"\n", + "tag = 'b1'\n", "checkpoints = trainer.train(\n", " model,\n", " optimizer,\n", " loss_fn,\n", " scheduler=scheduler,\n", - " batch_size=training_params[\"batch_size\"],\n", - " epochs=training_params[\"epochs\"],\n", + " batch_size=training_params['batch_size'],\n", + " epochs=training_params['epochs'],\n", " val_every_n_epochs=1,\n", - " with_CCS=training_params[\"with_CCS\"],\n", - " with_RT=training_params[\"with_RT\"],\n", + " with_CCS=training_params['with_CCS'],\n", + " with_RT=training_params['with_RT'],\n", " masked_validation=False,\n", " tag=tag,\n", ") # , mask_name=\"compiled_validation_maskALL\")" @@ -1925,10 +1926,10 @@ "from_checkpoint = True\n", "\n", "if from_checkpoint:\n", - " print(f\"Loading model from checkpoint {checkpoints}.\")\n", + " print(f'Loading model from checkpoint {checkpoints}.')\n", " end_model = GNNCompiler(model_params).to(dev)\n", " end_model = end_model.load_state_dict(model.state_dict())\n", - " model = model.load(checkpoints[\"file\"]).to(dev)" + " model = 
model.load(checkpoints['file']).to(dev)" ] }, { @@ -1956,21 +1957,21 @@ "source": [ "from fiora.MS.SimulationFramework import SimulationFramework\n", "\n", - "df[\"dataset\"] = df[\"group_id\"].apply(\n", + "df['dataset'] = df['group_id'].apply(\n", " lambda x: (\n", - " \"training\"\n", + " 'training'\n", " if trainer.is_group_in_training_set(x)\n", - " else \"validation\"\n", + " else 'validation'\n", " if trainer.is_group_in_validation_set(x)\n", - " else \"test\"\n", + " else 'test'\n", " )\n", ")\n", "\n", "fiora = SimulationFramework(\n", " model,\n", " dev=dev,\n", - " with_RT=training_params[\"with_RT\"],\n", - " with_CCS=training_params[\"with_CCS\"],\n", + " with_RT=training_params['with_RT'],\n", + " with_CCS=training_params['with_CCS'],\n", ")\n", "df = fiora.simulate_all(df, model)" ] @@ -2017,25 +2018,25 @@ ], "source": [ "reset_matplotlib()\n", - "for i, data in df[df[\"Metabolite\"] == ex_metabolite].iterrows():\n", - " cosine = data[\"spectral_sqrt_cosine\"]\n", - " name = data[\"Name\"]\n", - " print(f\"{name} ({i}): cosine {cosine:0.2}\")\n", + "for i, data in df[df['Metabolite'] == ex_metabolite].iterrows():\n", + " cosine = data['spectral_sqrt_cosine']\n", + " name = data['Name']\n", + " print(f'{name} ({i}): cosine {cosine:0.2}')\n", " fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", " )\n", - " img = data[\"Metabolite\"].draw(ax=axs[0])\n", + " img = data['Metabolite'].draw(ax=axs[0])\n", "\n", " # axs[0].grid(False)\n", " axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", " )\n", - " axs[0].set_title(data[\"Name\"])\n", + " axs[0].set_title(data['Name'])\n", " # axs[0].imshow(img)\n", " # axs[0].axis(\"off\")\n", " # sv.plot_spectrum(example, ax=axs[1])\n", " ax = sv.plot_spectrum(\n", - " data, {\"peaks\": data[\"sim_peaks\"]}, ax=axs[1], highlight_matches=False\n", + " data, {'peaks': data['sim_peaks']}, ax=axs[1], highlight_matches=False\n", " )\n", " plt.show()" ] @@ -2151,39 +2152,39 @@ "source": [ "df[\n", " [\n", - " \"cosine_similarity\",\n", - " \"kl_div\",\n", - " \"spectral_cosine\",\n", - " \"spectral_sqrt_cosine\",\n", - " \"spectral_refl_cosine\",\n", - " \"spectral_bias\",\n", - " \"spectral_sqrt_bias\",\n", - " \"spectral_refl_bias\",\n", + " 'cosine_similarity',\n", + " 'kl_div',\n", + " 'spectral_cosine',\n", + " 'spectral_sqrt_cosine',\n", + " 'spectral_refl_cosine',\n", + " 'spectral_bias',\n", + " 'spectral_sqrt_bias',\n", + " 'spectral_refl_bias',\n", " ]\n", "] = df[\n", " [\n", - " \"cosine_similarity\",\n", - " \"kl_div\",\n", - " \"spectral_cosine\",\n", - " \"spectral_sqrt_cosine\",\n", - " \"spectral_refl_cosine\",\n", - " \"spectral_bias\",\n", - " \"spectral_sqrt_bias\",\n", - " \"spectral_refl_bias\",\n", + " 'cosine_similarity',\n", + " 'kl_div',\n", + " 'spectral_cosine',\n", + " 'spectral_sqrt_cosine',\n", + " 'spectral_refl_cosine',\n", + " 'spectral_bias',\n", + " 'spectral_sqrt_bias',\n", + " 'spectral_refl_bias',\n", " ]\n", "].astype(float)\n", "# df[\"is_peptide\"] = df[\"Notes\"].apply(lambda x: \"Peptide\" in x) if lib == \"NIST\" else False\n", "\n", - "df_train = df[df[\"dataset\"] == \"training\"]\n", - "df_val = df[df[\"dataset\"] == \"validation\"]\n", - "df_val[\"library\"] = \"Validation\"\n", + "df_train = df[df['dataset'] == 
'training']\n", + "df_val = df[df['dataset'] == 'validation']\n", + "df_val['library'] = 'Validation'\n", "\n", - "for key in example[\"Metabolite\"].match_stats.keys():\n", - " df_val[key] = df[\"Metabolite\"].apply(lambda x: x.match_stats[key])\n", + "for key in example['Metabolite'].match_stats.keys():\n", + " df_val[key] = df['Metabolite'].apply(lambda x: x.match_stats[key])\n", "\n", - "df_val[\"ring_proportion\"] = df[\"Metabolite\"].apply(\n", + "df_val['ring_proportion'] = df['Metabolite'].apply(\n", " lambda x: (\n", - " getattr(x, \"is_edge_in_ring\").sum() / getattr(x, \"is_edge_in_ring\").shape[0]\n", + " getattr(x, 'is_edge_in_ring').sum() / getattr(x, 'is_edge_in_ring').shape[0]\n", " ).tolist()\n", ")" ] @@ -2241,14 +2242,14 @@ "source": [ "fig, axs = plt.subplots(1, 4, figsize=(19.2, 8.6), sharey=False)\n", "\n", - "sns.boxplot(ax=axs[0], data=df, y=\"spectral_cosine\", x=\"dataset\", palette=bluepink)\n", - "sns.boxplot(ax=axs[1], data=df, y=\"spectral_sqrt_cosine\", x=\"dataset\", palette=bluepink)\n", - "sns.boxplot(ax=axs[2], data=df, y=\"spectral_refl_cosine\", x=\"dataset\", palette=bluepink)\n", - "sns.boxplot(ax=axs[3], data=df, y=\"steins_cosine\", x=\"dataset\", palette=bluepink)\n", - "axs[0].set_title(\"Spectral cosine similarity\")\n", - "axs[1].set_title(\"Spectral cosine similarity of sqrt intensities\")\n", - "axs[2].set_title(\"Spectral cosine similarity of simulated peaks\")\n", - "axs[3].set_title(\"Steins dot product (reweighted with mass)\")\n", + "sns.boxplot(ax=axs[0], data=df, y='spectral_cosine', x='dataset', palette=bluepink)\n", + "sns.boxplot(ax=axs[1], data=df, y='spectral_sqrt_cosine', x='dataset', palette=bluepink)\n", + "sns.boxplot(ax=axs[2], data=df, y='spectral_refl_cosine', x='dataset', palette=bluepink)\n", + "sns.boxplot(ax=axs[3], data=df, y='steins_cosine', x='dataset', palette=bluepink)\n", + "axs[0].set_title('Spectral cosine similarity')\n", + "axs[1].set_title('Spectral cosine similarity of sqrt intensities')\n", + "axs[2].set_title('Spectral cosine similarity of simulated peaks')\n", + "axs[3].set_title('Steins dot product (reweighted with mass)')\n", "plt.show()" ] }, @@ -2269,7 +2270,7 @@ } ], "source": [ - "sns.boxplot(data=df_val, x=\"origin\", y=\"spectral_sqrt_cosine\")\n", + "sns.boxplot(data=df_val, x='origin', y='spectral_sqrt_cosine')\n", "plt.xticks(rotation=90)\n", "plt.show()" ] @@ -2427,20 +2428,20 @@ } ], "source": [ - "print(\"Stats at first glance (training and validation)\")\n", + "print('Stats at first glance (training and validation)')\n", "keys = [\n", - " \"cosine_similarity\",\n", - " \"kl_div\",\n", - " \"spectral_cosine\",\n", - " \"spectral_sqrt_cosine\",\n", - " \"spectral_refl_cosine\",\n", + " 'cosine_similarity',\n", + " 'kl_div',\n", + " 'spectral_cosine',\n", + " 'spectral_sqrt_cosine',\n", + " 'spectral_refl_cosine',\n", "]\n", "\n", "for key in keys:\n", - " blue = PRINT_COL[\"blue\"]\n", - " end = PRINT_COL[\"end\"]\n", + " blue = PRINT_COL['blue']\n", + " end = PRINT_COL['end']\n", " print(\n", - " f\"Median {key}: \\t{df_train[key].median():.2f} {blue} {df_val[key].median():.2f} {end}\"\n", + " f'Median {key}: \\t{df_train[key].median():.2f} {blue} {df_val[key].median():.2f} {end}'\n", " )" ] }, @@ -2479,19 +2480,19 @@ "fig, axs = plt.subplots(1, 4, figsize=(18, 6), sharey=False)\n", "\n", "sns.boxplot(\n", - " ax=axs[0], data=df, y=\"spectral_cosine\", hue=\"dataset\", palette=bluepink[:3]\n", + " ax=axs[0], data=df, y='spectral_cosine', hue='dataset', palette=bluepink[:3]\n", ")\n", 
"sns.boxplot(\n", - " ax=axs[1], data=df, y=\"spectral_sqrt_cosine\", hue=\"dataset\", palette=bluepink[:3]\n", + " ax=axs[1], data=df, y='spectral_sqrt_cosine', hue='dataset', palette=bluepink[:3]\n", ")\n", "sns.boxplot(\n", - " ax=axs[2], data=df, y=\"spectral_refl_cosine\", hue=\"dataset\", palette=bluepink[:3]\n", + " ax=axs[2], data=df, y='spectral_refl_cosine', hue='dataset', palette=bluepink[:3]\n", ")\n", - "sns.boxplot(ax=axs[3], data=df, y=\"steins_cosine\", hue=\"dataset\", palette=bluepink[:3])\n", - "axs[0].set_title(\"Spectral cosine similarity\")\n", - "axs[1].set_title(\"Spectral cosine similarity of sqrt intensities\")\n", - "axs[2].set_title(\"Spectral cosine similarity of simulated peaks\")\n", - "axs[3].set_title(\"Steins dot product (reweighted with mass)\")\n", + "sns.boxplot(ax=axs[3], data=df, y='steins_cosine', hue='dataset', palette=bluepink[:3])\n", + "axs[0].set_title('Spectral cosine similarity')\n", + "axs[1].set_title('Spectral cosine similarity of sqrt intensities')\n", + "axs[2].set_title('Spectral cosine similarity of simulated peaks')\n", + "axs[3].set_title('Steins dot product (reweighted with mass)')\n", "plt.show()" ] }, @@ -2514,18 +2515,18 @@ "source": [ "fig, axs = plt.subplots(1, 4, figsize=(18, 6), sharey=False)\n", "\n", - "sns.boxplot(ax=axs[0], data=df, y=\"spectral_bias\", hue=\"dataset\", palette=bluepink[:3])\n", + "sns.boxplot(ax=axs[0], data=df, y='spectral_bias', hue='dataset', palette=bluepink[:3])\n", "sns.boxplot(\n", - " ax=axs[1], data=df, y=\"spectral_sqrt_bias\", hue=\"dataset\", palette=bluepink[:3]\n", + " ax=axs[1], data=df, y='spectral_sqrt_bias', hue='dataset', palette=bluepink[:3]\n", ")\n", "sns.boxplot(\n", - " ax=axs[2], data=df, y=\"spectral_refl_bias\", hue=\"dataset\", palette=bluepink[:3]\n", + " ax=axs[2], data=df, y='spectral_refl_bias', hue='dataset', palette=bluepink[:3]\n", ")\n", - "sns.boxplot(ax=axs[3], data=df, y=\"steins_bias\", hue=\"dataset\", palette=bluepink[:3])\n", - "axs[0].set_title(\"Spectral cosine bias\")\n", - "axs[1].set_title(\"Spectral cosine bias of sqrt intensities\")\n", - "axs[2].set_title(\"Spectral cosine bias of simulated peaks\")\n", - "axs[3].set_title(\"Steins dot bias (reweighted with mass)\")\n", + "sns.boxplot(ax=axs[3], data=df, y='steins_bias', hue='dataset', palette=bluepink[:3])\n", + "axs[0].set_title('Spectral cosine bias')\n", + "axs[1].set_title('Spectral cosine bias of sqrt intensities')\n", + "axs[2].set_title('Spectral cosine bias of simulated peaks')\n", + "axs[3].set_title('Steins dot bias (reweighted with mass)')\n", "plt.show()" ] }, @@ -2560,11 +2561,11 @@ "sns.boxplot(\n", " ax=ax,\n", " data=df_val,\n", - " y=\"spectral_cosine\",\n", - " hue=\"Instrument_type\",\n", + " y='spectral_cosine',\n", + " hue='Instrument_type',\n", " palette=bluepink_grad8,\n", ")\n", - "axs[0].set_title(\"Cosine\")\n", + "axs[0].set_title('Cosine')\n", "plt.show()" ] }, @@ -2607,40 +2608,40 @@ "outputs": [], "source": [ "%%capture\n", - "df_cas[\"RETENTIONTIME\"] = df_cas[\"RTINSECONDS\"] / 60.0\n", - "df_cas[\"Metabolite\"] = df_cas[\"SMILES\"].apply(Metabolite)\n", - "df_cas[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df_cas['RETENTIONTIME'] = df_cas['RTINSECONDS'] / 60.0\n", + "df_cas['Metabolite'] = df_cas['SMILES'].apply(Metabolite)\n", + "df_cas['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", "\n", "if filter_spectra:\n", - " setup_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", - " 
setup_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", - "df_cas[\"Metabolite\"].apply(\n", + " setup_encoder.normalize_features['collision_energy']['max'] = CE_upper_limit\n", + " setup_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", + "df_cas['Metabolite'].apply(\n", " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", ")\n", - "df_cas[\"CE\"] = 20.0 # actually stepped 20/35/50\n", - "df_cas[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", + "df_cas['CE'] = 20.0 # actually stepped 20/35/50\n", + "df_cas['Instrument_type'] = 'HCD' # CHECK if correct Orbitrap\n", "\n", "metadata_key_map = {\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"PRECURSOR_MZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'precursor_mz': 'PRECURSOR_MZ',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'RETENTIONTIME',\n", "}\n", "\n", - "df_cas[\"summary\"] = df_cas.apply(\n", + "df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n", ")\n", "df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder), axis=1\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder), axis=1\n", ")\n", "\n", "# Fragmentation\n", - "df_cas[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df_cas['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=100 * PPM\n", " ),\n", " axis=1,\n", ") # Optional: use mz_cut instead" @@ -2655,37 +2656,37 @@ "%%capture\n", "from fiora.MOL.collision_energy import NCE_to_eV\n", "\n", - "df_cast[\"dataset\"] = \"CASMI 16 Training\"\n", - "df_cast[\"RETENTIONTIME\"] = df_cast[\"RTINSECONDS\"] / 60.0\n", - "df_cast[\"Metabolite\"] = df_cast[\"SMILES\"].apply(Metabolite)\n", - "df_cast[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df_cast['dataset'] = 'CASMI 16 Training'\n", + "df_cast['RETENTIONTIME'] = df_cast['RTINSECONDS'] / 60.0\n", + "df_cast['Metabolite'] = df_cast['SMILES'].apply(Metabolite)\n", + "df_cast['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cast[\"Metabolite\"].apply(\n", + "df_cast['Metabolite'].apply(\n", " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", ")\n", - "df_cast[\"CE\"] = 20.0 # actually stepped 20/35/50\n", - "df_cast[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", + "df_cast['CE'] = 20.0 # actually stepped 20/35/50\n", + "df_cast['Instrument_type'] = 'HCD' # CHECK if correct Orbitrap\n", "\n", "metadata_key_map16 = {\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"PRECURSOR_MZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'precursor_mz': 'PRECURSOR_MZ',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'RETENTIONTIME',\n", "}\n", "\n", - 
"df_cast[\"summary\"] = df_cast.apply(\n", + "df_cast['summary'] = df_cast.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", ")\n", "df_cast.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder), axis=1\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder), axis=1\n", ")\n", "\n", "# Fragmentation\n", - "df_cast[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df_cast['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "df_cast.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=100 * PPM\n", " ),\n", " axis=1,\n", ")" @@ -2707,7 +2708,7 @@ "source": [ "### CHECK IF train test split is correct\n", "\n", - "print(\"Skipping extra check. As struture disjoint testing is enforced on top.\")\n", + "print('Skipping extra check. As struture disjoint testing is enforced on top.')\n", "# iii = []\n", "# xxx = []\n", "# for i,d in df_cas.iterrows():\n", @@ -2742,18 +2743,18 @@ ], "source": [ "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", ")\n", "\n", - "img = df_cas.loc[0][\"Metabolite\"].draw(ax=axs[0])\n", + "img = df_cas.loc[0]['Metabolite'].draw(ax=axs[0])\n", "\n", "axs[0].grid(False)\n", "axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", ")\n", - "axs[0].set_title(str(df_cas.loc[0][\"Metabolite\"]))\n", + "axs[0].set_title(str(df_cas.loc[0]['Metabolite']))\n", "axs[0].imshow(img)\n", - "axs[0].axis(\"off\")\n", + "axs[0].axis('off')\n", "sv.plot_spectrum(df_cas.loc[0], ax=axs[1])" ] }, @@ -2840,94 +2841,93 @@ } ], "source": [ - "from fiora.MOL.collision_energy import NCE_to_eV\n", + "from fiora.MS.ms_utility import merge_annotated_spectrum\n", "from fiora.MS.spectral_scores import (\n", + " reweighted_dot,\n", " spectral_cosine,\n", " spectral_reflection_cosine,\n", - " reweighted_dot,\n", ")\n", - "from fiora.MS.ms_utility import merge_annotated_spectrum\n", "\n", "\n", "def test_cas(df_cas):\n", - " df_cas[\"NCE\"] = 20.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " df_cas['NCE'] = 20.0 # actually stepped NCE 20/35/50\n", + " df_cas['CE'] = df_cas[['NCE', 'PRECURSOR_MZ']].apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['PRECURSOR_MZ']), axis=1\n", " )\n", - " df_cas[\"step1_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(\n", + " df_cas['step1_CE'] = df_cas['CE']\n", + " df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n", " )\n", " df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", " )\n", - " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_20\")\n", + " df_cas = fiora.simulate_all(df_cas, model, suffix='_20')\n", "\n", - " df_cas[\"NCE\"] = 35.0 # actually stepped NCE 20/35/50\n", - " 
df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " df_cas['NCE'] = 35.0 # actually stepped NCE 20/35/50\n", + " df_cas['CE'] = df_cas[['NCE', 'PRECURSOR_MZ']].apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['PRECURSOR_MZ']), axis=1\n", " )\n", - " df_cas[\"step2_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(\n", + " df_cas['step2_CE'] = df_cas['CE']\n", + " df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n", " )\n", " df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", " )\n", - " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_35\")\n", + " df_cas = fiora.simulate_all(df_cas, model, suffix='_35')\n", "\n", - " df_cas[\"NCE\"] = 50.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " df_cas['NCE'] = 50.0 # actually stepped NCE 20/35/50\n", + " df_cas['CE'] = df_cas[['NCE', 'PRECURSOR_MZ']].apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['PRECURSOR_MZ']), axis=1\n", " )\n", - " df_cas[\"step3_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(\n", + " df_cas['step3_CE'] = df_cas['CE']\n", + " df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n", " )\n", " df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", " )\n", - " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_50\")\n", + " df_cas = fiora.simulate_all(df_cas, model, suffix='_50')\n", "\n", - " df_cas[\"avg_CE\"] = (\n", - " df_cas[\"step1_CE\"] + df_cas[\"step2_CE\"] + df_cas[\"step3_CE\"]\n", + " df_cas['avg_CE'] = (\n", + " df_cas['step1_CE'] + df_cas['step2_CE'] + df_cas['step3_CE']\n", " ) / 3\n", "\n", - " df_cas[\"merged_peaks\"] = df_cas.apply(\n", + " df_cas['merged_peaks'] = df_cas.apply(\n", " lambda x: merge_annotated_spectrum(\n", - " merge_annotated_spectrum(x[\"sim_peaks_20\"], x[\"sim_peaks_35\"]),\n", - " x[\"sim_peaks_50\"],\n", + " merge_annotated_spectrum(x['sim_peaks_20'], x['sim_peaks_35']),\n", + " x['sim_peaks_50'],\n", " ),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_cosine\"] = df_cas.apply(\n", - " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " df_cas['merged_cosine'] = df_cas.apply(\n", + " lambda x: spectral_cosine(x['peaks'], x['merged_peaks']), axis=1\n", " )\n", - " df_cas[\"merged_sqrt_cosine\"] = df_cas.apply(\n", - " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt),\n", + " df_cas['merged_sqrt_cosine'] = df_cas.apply(\n", + " lambda x: spectral_cosine(x['peaks'], x['merged_peaks'], transform=np.sqrt),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_refl_cosine\"] = df_cas.apply(\n", + " df_cas['merged_refl_cosine'] = df_cas.apply(\n", " lambda x: spectral_reflection_cosine(\n", - " x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt\n", + " x['peaks'], x['merged_peaks'], transform=np.sqrt\n", " ),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_steins\"] = df_cas.apply(\n", - " lambda x: reweighted_dot(x[\"peaks\"], x[\"merged_peaks\"]), 
axis=1\n", + " df_cas['merged_steins'] = df_cas.apply(\n", + " lambda x: reweighted_dot(x['peaks'], x['merged_peaks']), axis=1\n", " )\n", - " df_cas[\"spectral_sqrt_cosine\"] = df_cas[\n", - " \"merged_sqrt_cosine\"\n", + " df_cas['spectral_sqrt_cosine'] = df_cas[\n", + " 'merged_sqrt_cosine'\n", " ] # just remember it is merged\n", "\n", - " df_cas[\"coverage\"] = df_cas[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", - " df_cas[\"RT_pred\"] = df_cas[\"RT_pred_35\"]\n", - " df_cas[\"RT_dif\"] = df_cas[\"RT_dif_35\"]\n", - " df_cas[\"CCS_pred\"] = df_cas[\"CCS_pred_35\"]\n", - " df_cas[\"library\"] = \"CASMI-16\"\n", + " df_cas['coverage'] = df_cas['Metabolite'].apply(lambda x: x.match_stats['coverage'])\n", + " df_cas['RT_pred'] = df_cas['RT_pred_35']\n", + " df_cas['RT_dif'] = df_cas['RT_dif_35']\n", + " df_cas['CCS_pred'] = df_cas['CCS_pred_35']\n", + " df_cas['library'] = 'CASMI-16'\n", "\n", " return df_cas\n", "\n", @@ -2976,23 +2976,23 @@ "fig, ax = plt.subplots(1, 1, figsize=(8, 4))\n", "sns.histplot(\n", " df_cas,\n", - " x=\"avg_CE\",\n", - " hue=\"Precursor_type\",\n", - " multiple=\"stack\",\n", - " palette=[\"black\", \"gray\"],\n", + " x='avg_CE',\n", + " hue='Precursor_type',\n", + " multiple='stack',\n", + " palette=['black', 'gray'],\n", ") # bluepink[:2][::-1])\n", - "plt.xlabel(\"Average collision energy\")\n", + "plt.xlabel('Average collision energy')\n", "plt.show()\n", "\n", "fig, ax = plt.subplots(1, 1, figsize=(8, 4))\n", "sns.histplot(\n", " df_cast,\n", - " x=\"avg_CE\",\n", - " hue=\"Precursor_type\",\n", - " multiple=\"stack\",\n", - " palette=[\"black\", \"gray\"],\n", + " x='avg_CE',\n", + " hue='Precursor_type',\n", + " multiple='stack',\n", + " palette=['black', 'gray'],\n", ") # bluepink[:2][::-1])\n", - "plt.xlabel(\"Average collision energy\")\n", + "plt.xlabel('Average collision energy')\n", "plt.show()" ] }, @@ -3055,24 +3055,24 @@ ], "source": [ "i = 3\n", - "print(df_cas.loc[i][\"merged_sqrt_cosine\"])\n", + "print(df_cas.loc[i]['merged_sqrt_cosine'])\n", "\n", "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", ")\n", - "img = df_cas.loc[i][\"Metabolite\"].draw(ax=axs[0])\n", + "img = df_cas.loc[i]['Metabolite'].draw(ax=axs[0])\n", "\n", "axs[0].grid(False)\n", "axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", ")\n", - "axs[0].set_title(df_cas.loc[i][\"NAME\"])\n", + "axs[0].set_title(df_cas.loc[i]['NAME'])\n", "axs[0].imshow(img)\n", - "axs[0].axis(\"off\")\n", + "axs[0].axis('off')\n", "# sv.plot_spectrum(example, ax=axs[1])\n", "ax = sv.plot_spectrum(\n", - " {\"peaks\": df_cas.loc[i][\"peaks\"]},\n", - " {\"peaks\": df_cas.loc[i][\"sim_peaks_35\"]},\n", + " {'peaks': df_cas.loc[i]['peaks']},\n", + " {'peaks': df_cas.loc[i]['sim_peaks_35']},\n", " ax=axs[1],\n", ")\n", "# axs[1].text(0.5, 0.5, 'matplotlib', horizontalalignment='center', verticalalignment='center', transform=axs[1].transAxes)\n", @@ -3113,11 +3113,11 @@ "sns.boxplot(\n", " ax=ax,\n", " data=df_cas,\n", - " y=\"merged_sqrt_cosine\",\n", - " x=\"Precursor_type\",\n", + " y='merged_sqrt_cosine',\n", + " x='Precursor_type',\n", " palette=bluepink[:2][::-1],\n", ")\n", - "ax.set_title(\"Spectral cosine similarity of sqrt intensities\")\n", + "ax.set_title('Spectral cosine 
similarity of sqrt intensities')\n", "plt.show()" ] }, @@ -3142,7 +3142,7 @@ " return str(ref[i])\n", "\n", "\n", -    "df_cas[\"cfm_CE\"] = df_cas[\"avg_CE\"].apply(closest_cfm_ce)" +    "df_cas['cfm_CE'] = df_cas['avg_CE'].apply(closest_cfm_ce)" ] }, { @@ -3160,13 +3160,14 @@ "outputs": [], "source": [ "import fiora.IO.cfmReader as cfmReader\n", +    "\n", "# time CFM-ID 4: -> 12m16,571s\n", "\n", "cf = cfmReader.read(\n", -    " f\"{home}/data/metabolites/cfm-id/casmi16_negative_predictions.txt\", as_df=True\n", +    " f'{home}/data/metabolites/cfm-id/casmi16_negative_predictions.txt', as_df=True\n", ")\n", "cf_p = cfmReader.read(\n", -    " f\"{home}/data/metabolites/cfm-id/casmi16_positive_predictions.txt\", as_df=True\n", +    " f'{home}/data/metabolites/cfm-id/casmi16_positive_predictions.txt', as_df=True\n", ")\n", "cf = pd.concat([cf, cf_p])" ] }, @@ -3188,7 +3189,7 @@ } ], "source": [ -    "len(cf[cf[\"#ID\"] == \"Challenge-009\"]) ## missing chalenges" +    "len(cf[cf['#ID'] == 'Challenge-009']) ## missing challenges" ] }, { @@ -3216,28 +3217,28 @@ } ], "source": [ -    "df_cas[\"cfm_peaks\"] = None\n", -    "df_cas[[\"cfm_cosine\", \"cfm_sqrt_cosine\", \"cfm_refl_cosine\"]] = np.nan\n", +    "df_cas['cfm_peaks'] = None\n", +    "df_cas[['cfm_cosine', 'cfm_sqrt_cosine', 'cfm_refl_cosine']] = np.nan\n", "for i, cas in df_cas.iterrows():\n", -    " challenge = cas[\"ChallengeName\"]\n", +    " challenge = cas['ChallengeName']\n", "\n", -    " if len(cf[cf[\"#ID\"] == challenge]) != 1:\n", -    " print(f\"{challenge} not found in CFM-ID results. Skipping.\")\n", +    " if len(cf[cf['#ID'] == challenge]) != 1:\n", +    " print(f'{challenge} not found in CFM-ID results. Skipping.')\n", " continue\n", -    " cfm_data = cf[cf[\"#ID\"] == challenge].iloc[0]\n", -    "\n", -    " if cas[\"ChallengeName\"] != cfm_data[\"#ID\"]:\n", -    " raise ValueError(\"Wrong challenge matched\")\n", -    " cfm_peaks = cfm_data[\"peaks\" + cas[\"cfm_CE\"]] # find best reference CE\n", -    " df_cas.at[i, \"cfm_peaks\"] = cfm_peaks\n", -    " df_cas.at[i, \"cfm_cosine\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks)\n", -    " df_cas.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(\n", -    " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", +    " cfm_data = cf[cf['#ID'] == challenge].iloc[0]\n", +    "\n", +    " if cas['ChallengeName'] != cfm_data['#ID']:\n", +    " raise ValueError('Wrong challenge matched')\n", +    " cfm_peaks = cfm_data['peaks' + cas['cfm_CE']] # find best reference CE\n", +    " df_cas.at[i, 'cfm_peaks'] = cfm_peaks\n", +    " df_cas.at[i, 'cfm_cosine'] = spectral_cosine(cas['peaks'], cfm_peaks)\n", +    " df_cas.at[i, 'cfm_sqrt_cosine'] = spectral_cosine(\n", +    " cas['peaks'], cfm_peaks, transform=np.sqrt\n", " )\n", -    " df_cas.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(\n", -    " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", +    " df_cas.at[i, 'cfm_refl_cosine'] = spectral_reflection_cosine(\n", +    " cas['peaks'], cfm_peaks, transform=np.sqrt\n", " )\n", -    " df_cas.at[i, \"cfm_steins\"] = reweighted_dot(cas[\"peaks\"], cfm_peaks)" +    " df_cas.at[i, 'cfm_steins'] = reweighted_dot(cas['peaks'], cfm_peaks)" ] }, { @@ -3258,37 +3259,37 @@ ], "source": [ "cosines = {\n", -    " \"cosine\": df_cas[\"merged_cosine\"],\n", -    " \"sqrt_cosine\": df_cas[\"merged_sqrt_cosine\"],\n", -    " \"refl_cosine\": df_cas[\"merged_refl_cosine\"],\n", -    " \"steins_dot\": df_cas[\"merged_steins\"],\n", -    " \"model\": \"Fiora\",\n", +    " 'cosine': df_cas['merged_cosine'],\n", +    " 'sqrt_cosine': df_cas['merged_sqrt_cosine'],\n", +    " 'refl_cosine': df_cas['merged_refl_cosine'],\n", +    " 'steins_dot': 
df_cas['merged_steins'],\n", + " 'model': 'Fiora',\n", "}\n", "cosines2 = {\n", - " \"cosine\": df_cas[\"cfm_cosine\"],\n", - " \"sqrt_cosine\": df_cas[\"cfm_sqrt_cosine\"],\n", - " \"refl_cosine\": df_cas[\"cfm_refl_cosine\"],\n", - " \"steins_dot\": df_cas[\"cfm_steins\"],\n", - " \"model\": \"CFM-ID 4.4.7\",\n", + " 'cosine': df_cas['cfm_cosine'],\n", + " 'sqrt_cosine': df_cas['cfm_sqrt_cosine'],\n", + " 'refl_cosine': df_cas['cfm_refl_cosine'],\n", + " 'steins_dot': df_cas['cfm_steins'],\n", + " 'model': 'CFM-ID 4.4.7',\n", "}\n", "fig, axs = plt.subplots(\n", - " 1, 3, figsize=(18.8, 4.2), gridspec_kw={\"width_ratios\": [1, 1, 1]}, sharey=False\n", + " 1, 3, figsize=(18.8, 4.2), gridspec_kw={'width_ratios': [1, 1, 1]}, sharey=False\n", ")\n", "\n", "scores = pd.concat([pd.DataFrame(cosines), pd.DataFrame(cosines2)], ignore_index=True)\n", "sns.histplot(\n", - " scores, ax=axs[0], x=\"cosine\", hue=\"model\", palette=bluepink[:2], binwidth=0.025\n", + " scores, ax=axs[0], x='cosine', hue='model', palette=bluepink[:2], binwidth=0.025\n", ")\n", "sns.histplot(\n", " scores,\n", " ax=axs[1],\n", - " x=\"sqrt_cosine\",\n", - " hue=\"model\",\n", + " x='sqrt_cosine',\n", + " hue='model',\n", " palette=bluepink[:2],\n", " binwidth=0.025,\n", ")\n", "sns.histplot(\n", - " scores, ax=axs[2], x=\"steins_dot\", hue=\"model\", palette=bluepink[:2], binwidth=0.025\n", + " scores, ax=axs[2], x='steins_dot', hue='model', palette=bluepink[:2], binwidth=0.025\n", ")\n", "plt.show()" ] @@ -3334,9 +3335,6 @@ " l.set_xdata([xmin_new, xmax_new])\n", "\n", "\n", - "from matplotlib.patches import PathPatch\n", - "\n", - "\n", "def adjust_box_widths(fig, fac):\n", " \"\"\"\n", " Adjust the widths of a seaborn-generated boxplot.\n", @@ -3369,8 +3367,8 @@ " if len(l.get_xdata()) > 0:\n", " # check if the line is a median line\n", " if (\n", - " \"color\" in l.properties()\n", - " and l.properties()[\"color\"] == \"black\"\n", + " 'color' in l.properties()\n", + " and l.properties()['color'] == 'black'\n", " ):\n", " l.set_xdata([xmin_new, xmax_new])" ] @@ -3420,7 +3418,7 @@ } ], "source": [ - "df_cas.groupby(\"Precursor_type\")[\"merged_sqrt_cosine\"].median()" + "df_cas.groupby('Precursor_type')['merged_sqrt_cosine'].median()" ] }, { @@ -3439,9 +3437,9 @@ } ], "source": [ - "print(\"Mean:\\t\", round(df_cas[\"merged_sqrt_cosine\"].mean(), 2)) # 0.634\n", - "print(\"Median:\\t\", round(df_cas[\"merged_sqrt_cosine\"].median(), 2)) # 0.737\n", - "print(\"Var:\\t\", round(df_cas[\"merged_sqrt_cosine\"].var(), 2)) # 0.116" + "print('Mean:\\t', round(df_cas['merged_sqrt_cosine'].mean(), 2)) # 0.634\n", + "print('Median:\\t', round(df_cas['merged_sqrt_cosine'].median(), 2)) # 0.737\n", + "print('Var:\\t', round(df_cas['merged_sqrt_cosine'].var(), 2)) # 0.116" ] }, { @@ -3460,9 +3458,9 @@ } ], "source": [ - "print(\"Mean:\\t\", round(df_cas[\"cfm_sqrt_cosine\"].mean(), 2))\n", - "print(\"Median:\\t\", round(df_cas[\"cfm_sqrt_cosine\"].median(), 2))\n", - "print(\"Var:\\t\", round(df_cas[\"cfm_sqrt_cosine\"].var(), 2))" + "print('Mean:\\t', round(df_cas['cfm_sqrt_cosine'].mean(), 2))\n", + "print('Median:\\t', round(df_cas['cfm_sqrt_cosine'].median(), 2))\n", + "print('Var:\\t', round(df_cas['cfm_sqrt_cosine'].var(), 2))" ] }, { @@ -3479,12 +3477,12 @@ } ], "source": [ - "df_cas[\"higher_cosine\"] = (df_cas[\"merged_sqrt_cosine\"] - df_cas[\"cfm_sqrt_cosine\"]) > 0\n", - "df_cas[\"smaller_cosine\"] = (\n", - " df_cas[\"merged_sqrt_cosine\"] - df_cas[\"cfm_sqrt_cosine\"]\n", + "df_cas['higher_cosine'] = 
(df_cas['merged_sqrt_cosine'] - df_cas['cfm_sqrt_cosine']) > 0\n", +    "df_cas['smaller_cosine'] = (\n", +    " df_cas['merged_sqrt_cosine'] - df_cas['cfm_sqrt_cosine']\n", ") < 0\n", -    "h, l = sum(df_cas[\"higher_cosine\"]), sum(df_cas[\"smaller_cosine\"])\n", -    "print(f\"Higher in {h} of cases (smaller in {l} cases) out of {df_cas.shape[0]}\")" +    "h, l = sum(df_cas['higher_cosine']), sum(df_cas['smaller_cosine'])\n", +    "print(f'Higher in {h} cases (smaller in {l} cases) out of {df_cas.shape[0]}')" ] }, { @@ -3501,10 +3499,10 @@ } ], "source": [ -    "df_cas[\"higher_cosine\"] = (df_cas[\"merged_steins\"] - df_cas[\"cfm_steins\"]) > 0\n", -    "df_cas[\"smaller_cosine\"] = (df_cas[\"merged_steins\"] - df_cas[\"cfm_steins\"]) < 0\n", -    "h, l = sum(df_cas[\"higher_cosine\"]), sum(df_cas[\"smaller_cosine\"])\n", -    "print(f\"Higher in {h} of cases (smaller in {l} cases) out of {df_cas.shape[0]}\")" +    "df_cas['higher_cosine'] = (df_cas['merged_steins'] - df_cas['cfm_steins']) > 0\n", +    "df_cas['smaller_cosine'] = (df_cas['merged_steins'] - df_cas['cfm_steins']) < 0\n", +    "h, l = sum(df_cas['higher_cosine']), sum(df_cas['smaller_cosine'])\n", +    "print(f'Higher in {h} cases (smaller in {l} cases) out of {df_cas.shape[0]}')" ] }, { @@ -3522,11 +3520,12 @@ "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", +    "import pandas as pd\n", +    "import seaborn as sns\n", "import spectrum_utils.plot as sup\n", "import spectrum_utils.spectrum as sus\n", -    "import seaborn as sns\n", -    "from pyteomics import pylab_aux as pa, usi\n", -    "import pandas as pd" +    "from pyteomics import pylab_aux as pa\n", +    "from pyteomics import usi" ] }, { @@ -3538,53 +3537,53 @@ "import matplotlib.patches as mpatches\n", "\n", "\n", -    "def double_mirrorplot(i, model_title=\"Fiora\"):\n", +    "def double_mirrorplot(i, model_title='Fiora'):\n", " fig, axs = plt.subplots(\n", -    " 1, 3, figsize=(16.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3, 3]}, sharey=False\n", +    " 1, 3, figsize=(16.8, 4.2), gridspec_kw={'width_ratios': [1, 3, 3]}, sharey=False\n", " )\n", "\n", " plt.subplots_adjust(right=0.975, left=0.025)\n", "\n", -    " img = df_cas.loc[i][\"Metabolite\"].draw(ax=axs[0])\n", +    " img = df_cas.loc[i]['Metabolite'].draw(ax=axs[0])\n", "\n", " axs[0].grid(False)\n", " axs[0].tick_params(\n", -    " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", +    " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", " )\n", " axs[0].set_title(\n", -    " df_cas.loc[i][\"NAME\"] + \"\\n(\" + df_cas.loc[i][\"ChallengeName\"] + \")\"\n", +    " df_cas.loc[i]['NAME'] + '\\n(' + df_cas.loc[i]['ChallengeName'] + ')'\n", " )\n", " axs[0].imshow(img)\n", -    " axs[0].axis(\"off\")\n", +    " axs[0].axis('off')\n", "\n", " sv.plot_spectrum(\n", -    " {\"peaks\": df_cas.loc[i][\"peaks\"]},\n", -    " {\"peaks\": df_cas.loc[i][\"merged_peaks\"]},\n", +    " {'peaks': df_cas.loc[i]['peaks']},\n", +    " {'peaks': df_cas.loc[i]['merged_peaks']},\n", " ax=axs[1],\n", " )\n", " axs[1].title.set_text(model_title)\n", " patch1 = mpatches.Patch(\n", -    " color=\"limegreen\"\n", -    " if df_cas.loc[i][\"cfm_sqrt_cosine\"] < df_cas.loc[i][\"merged_sqrt_cosine\"]\n", -    " else \"orangered\",\n", -    " label=f\"cosine {df_cas.loc[i]['merged_sqrt_cosine']:.02f}\",\n", +    " color='limegreen'\n", +    " if df_cas.loc[i]['cfm_sqrt_cosine'] < df_cas.loc[i]['merged_sqrt_cosine']\n", +    " else 'orangered',\n", +    " label=f'cosine {df_cas.loc[i][\"merged_sqrt_cosine\"]:.02f}',\n", " )\n", " axs[1].legend(handles=[patch1])\n", "\n", " 
sv.plot_spectrum(\n", - " {\"peaks\": df_cas.loc[i][\"peaks\"]},\n", - " {\"peaks\": df_cas.loc[i][\"cfm_peaks\"]}\n", - " if df_cas.loc[i][\"cfm_peaks\"]\n", - " else {\"peaks\": {\"mz\": [0], \"intensity\": [0]}},\n", + " {'peaks': df_cas.loc[i]['peaks']},\n", + " {'peaks': df_cas.loc[i]['cfm_peaks']}\n", + " if df_cas.loc[i]['cfm_peaks']\n", + " else {'peaks': {'mz': [0], 'intensity': [0]}},\n", " ax=axs[2],\n", " )\n", - " axs[2].title.set_text(f\"CFM-ID 4.4.7\")\n", + " axs[2].title.set_text('CFM-ID 4.4.7')\n", "\n", " patch2 = mpatches.Patch(\n", - " color=\"limegreen\"\n", - " if df_cas.loc[i][\"cfm_sqrt_cosine\"] > df_cas.loc[i][\"merged_sqrt_cosine\"]\n", - " else \"orangered\",\n", - " label=f\"cosine {df_cas.loc[i]['cfm_sqrt_cosine']:.02f}\",\n", + " color='limegreen'\n", + " if df_cas.loc[i]['cfm_sqrt_cosine'] > df_cas.loc[i]['merged_sqrt_cosine']\n", + " else 'orangered',\n", + " label=f'cosine {df_cas.loc[i][\"cfm_sqrt_cosine\"]:.02f}',\n", " )\n", " axs[2].legend(handles=[patch2])\n", "\n", @@ -3622,7 +3621,7 @@ } ], "source": [ - "test_obj = df_cas.iloc[1][\"Metabolite\"]\n", + "test_obj = df_cas.iloc[1]['Metabolite']\n", "test_obj.draw()\n", "plt.show()" ] @@ -3673,8 +3672,8 @@ " for i in thecaseforfioracase:\n", " fig, axs = double_mirrorplot(i)\n", " plt.savefig(\n", - " f\"{home}/images/\" + df_cas.at[i, \"ChallengeName\"] + \"_mirror.svg\",\n", - " format=\"svg\",\n", + " f'{home}/images/' + df_cas.at[i, 'ChallengeName'] + '_mirror.svg',\n", + " format='svg',\n", " )\n", " plt.clf()" ] @@ -3931,8 +3930,8 @@ "# df_cas22[col] = df_cas22[col].apply(ast.literal_eval)\n", "\n", "print(df_cas22.shape)\n", - "df_cas22[\"ChallengeNum\"] = df_cas22[\"ChallengeName\"].apply(\n", - " lambda x: int(x.split(\"-\")[-1])\n", + "df_cas22['ChallengeNum'] = df_cas22['ChallengeName'].apply(\n", + " lambda x: int(x.split('-')[-1])\n", ")\n", "try:\n", " df_cas22.reset_index(inplace=True)\n", @@ -3948,37 +3947,37 @@ "outputs": [], "source": [ "%%capture\n", - "df_cas22[\"Metabolite\"] = df_cas22[\"SMILES\"].apply(Metabolite)\n", - "df_cas22[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df_cas22['Metabolite'] = df_cas22['SMILES'].apply(Metabolite)\n", + "df_cas22['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cas22[\"Metabolite\"].apply(\n", + "df_cas22['Metabolite'].apply(\n", " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", ")\n", - "df_cas22[\"CE\"] = df_cas22.apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"precursor_mz\"]), axis=1\n", + "df_cas22['CE'] = df_cas22.apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['precursor_mz']), axis=1\n", ")\n", "\n", "metadata_key_map = {\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"precursor_mz\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"ChallengeRT\",\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'precursor_mz': 'precursor_mz',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'ChallengeRT',\n", "}\n", "\n", - "df_cas22[\"summary\"] = df_cas22.apply(\n", + "df_cas22['summary'] = df_cas22.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1\n", ")\n", "df_cas22.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", 
")\n", "\n", "# Fragmentation\n", - "df_cas22[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df_cas22['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "df_cas22.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=100 * PPM\n", " ),\n", " axis=1,\n", ") # Optional: use mz_cut instead" @@ -4003,9 +4002,9 @@ } ], "source": [ - "df_cas22_unique = df_cas22.drop_duplicates(subset=\"ChallengeName\", keep=\"first\")\n", + "df_cas22_unique = df_cas22.drop_duplicates(subset='ChallengeName', keep='first')\n", "# df_cas22_unique.reset_index(inplace=True)\n", - "df_cas22_unique[\"Metabolite\"] = df_cas22_unique[\"SMILES\"].apply(Metabolite)" + "df_cas22_unique['Metabolite'] = df_cas22_unique['SMILES'].apply(Metabolite)" ] }, { @@ -4025,7 +4024,7 @@ "### CHECK IF train test split is correct\n", "\n", "\n", - "print(\"Skipping extra check. As struture disjoint testing is enforced on top.\")\n", + "print('Skipping extra check. As struture disjoint testing is enforced on top.')\n", "# iii = []\n", "# xxx = []\n", "\n", @@ -4062,18 +4061,18 @@ ], "source": [ "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", ")\n", "\n", - "img = df_cas22.loc[0][\"Metabolite\"].draw(ax=axs[0])\n", + "img = df_cas22.loc[0]['Metabolite'].draw(ax=axs[0])\n", "\n", "axs[0].grid(False)\n", "axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", ")\n", - "axs[0].set_title(str(df_cas22.loc[0][\"Metabolite\"]))\n", + "axs[0].set_title(str(df_cas22.loc[0]['Metabolite']))\n", "axs[0].imshow(img)\n", - "axs[0].axis(\"off\")\n", + "axs[0].axis('off')\n", "sv.plot_spectrum(df_cas22.loc[0], ax=axs[1])" ] }, @@ -4135,24 +4134,24 @@ "sns.histplot(\n", " ax=axs[0],\n", " data=df_cas22,\n", - " x=\"CE\",\n", - " hue=\"Precursor_type\",\n", - " multiple=\"stack\",\n", - " palette=[\"black\", \"gray\"],\n", + " x='CE',\n", + " hue='Precursor_type',\n", + " multiple='stack',\n", + " palette=['black', 'gray'],\n", ") # bluepink[:2][::-1])\n", - "axs[0].vlines(CE_upper_limit, 0, 120, color=\"red\")\n", - "axs[0].set_xlabel(\"Collision energy\")\n", + "axs[0].vlines(CE_upper_limit, 0, 120, color='red')\n", + "axs[0].set_xlabel('Collision energy')\n", "\n", "sns.histplot(\n", " ax=axs[1],\n", " data=df_cas22,\n", - " x=\"precursor_mz\",\n", - " hue=\"Precursor_type\",\n", - " multiple=\"stack\",\n", - " palette=[\"black\", \"gray\"],\n", + " x='precursor_mz',\n", + " hue='Precursor_type',\n", + " multiple='stack',\n", + " palette=['black', 'gray'],\n", ") # bluepink[:2][::-1])\n", "# axs[1].vlines(800, 0, 120, color=\"red\")\n", - "axs[1].set_xlabel(\"Precursor mz\")\n", + "axs[1].set_xlabel('Precursor mz')\n", "plt.show()" ] }, @@ -4162,7 +4161,7 @@ "metadata": {}, "outputs": [], "source": [ - "df_cas22 = df_cas22[df_cas22[\"CE\"] <= CE_upper_limit]\n", + "df_cas22 = df_cas22[df_cas22['CE'] <= CE_upper_limit]\n", "# df_cas22 = df_cas22[df_cas22[\"precursor_mz\"] <= 800.0]" ] }, @@ -4172,10 +4171,10 @@ "metadata": {}, "outputs": [], "source": [ - "df_cas22[\"coverage\"] = df_cas22[\"Metabolite\"].apply(lambda x: 
x.match_stats[\"coverage\"])\n", - "df_cas22[\"ring_proportion\"] = df_cas22[\"Metabolite\"].apply(\n", + "df_cas22['coverage'] = df_cas22['Metabolite'].apply(lambda x: x.match_stats['coverage'])\n", + "df_cas22['ring_proportion'] = df_cas22['Metabolite'].apply(\n", " lambda x: (\n", - " getattr(x, \"is_edge_in_ring\").sum() / getattr(x, \"is_edge_in_ring\").shape[0]\n", + " getattr(x, 'is_edge_in_ring').sum() / getattr(x, 'is_edge_in_ring').shape[0]\n", " ).tolist()\n", ")" ] @@ -4189,18 +4188,18 @@ "sns.boxplot(\n", " ax=axs[0],\n", " data=df_cas22,\n", - " y=\"spectral_sqrt_cosine\",\n", - " x=\"Precursor_type\",\n", - " hue=\"NCE\",\n", + " y='spectral_sqrt_cosine',\n", + " x='Precursor_type',\n", + " hue='NCE',\n", " palette=bluepink_grad8[-3:][::-1],\n", ")\n", - "axs[0].set_title(\"Spectral cosine similarity of sqrt intensities\")\n", + "axs[0].set_title('Spectral cosine similarity of sqrt intensities')\n", "sns.scatterplot(\n", " ax=axs[1],\n", " data=df_cas22,\n", - " x=\"coverage\",\n", - " y=\"spectral_sqrt_cosine\",\n", - " hue=\"spectral_sqrt_cosine\",\n", + " x='coverage',\n", + " y='spectral_sqrt_cosine',\n", + " hue='spectral_sqrt_cosine',\n", " hue_norm=(0, 1),\n", " palette=bluepink_grad,\n", ")\n", @@ -4232,24 +4231,24 @@ ], "source": [ "i = 3\n", - "print(df_cas22.loc[i][\"spectral_sqrt_cosine\"])\n", + "print(df_cas22.loc[i]['spectral_sqrt_cosine'])\n", "\n", "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", ")\n", - "img = df_cas22.loc[i][\"Metabolite\"].draw(ax=axs[0])\n", + "img = df_cas22.loc[i]['Metabolite'].draw(ax=axs[0])\n", "\n", "axs[0].grid(False)\n", "axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", ")\n", - "axs[0].set_title(df_cas.loc[i][\"NAME\"])\n", + "axs[0].set_title(df_cas.loc[i]['NAME'])\n", "axs[0].imshow(img)\n", - "axs[0].axis(\"off\")\n", + "axs[0].axis('off')\n", "# sv.plot_spectrum(example, ax=axs[1])\n", "ax = sv.plot_spectrum(\n", - " {\"peaks\": df_cas22.loc[i][\"peaks\"]},\n", - " {\"peaks\": df_cas22.loc[i][\"sim_peaks\"]},\n", + " {'peaks': df_cas22.loc[i]['peaks']},\n", + " {'peaks': df_cas22.loc[i]['sim_peaks']},\n", " ax=axs[1],\n", ")\n", "# axs[1].text(0.5, 0.5, 'matplotlib', horizontalalignment='center', verticalalignment='center', transform=axs[1].transAxes)\n", @@ -4274,35 +4273,35 @@ ], "source": [ "fig, axs = plt.subplots(\n", - " 2, 1, figsize=(12, 14), sharex=True, gridspec_kw={\"height_ratios\": [1, 5]}\n", + " 2, 1, figsize=(12, 14), sharex=True, gridspec_kw={'height_ratios': [1, 5]}\n", ")\n", "plt.subplots_adjust(hspace=0.05) # right=0.975, left=0.11)\n", "# sns.histplot(ax=axs[0], data=df_cas22, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", "sns.kdeplot(\n", " ax=axs[0],\n", " data=df_cas22,\n", - " x=\"coverage\",\n", + " x='coverage',\n", " bw_adjust=0.2,\n", - " color=\"black\",\n", - " multiple=\"stack\",\n", - " hue=\"Precursor_type\",\n", - " palette=[\"black\", \"gray\"],\n", + " color='black',\n", + " multiple='stack',\n", + " hue='Precursor_type',\n", + " palette=['black', 'gray'],\n", ") # hue=\"Precursor_type\",\n", - "axs[0].spines[\"top\"].set_visible(False)\n", - 
"axs[0].spines[\"right\"].set_visible(False)\n", + "axs[0].spines['top'].set_visible(False)\n", + "axs[0].spines['right'].set_visible(False)\n", "\n", - "axs[0].set_title(\"Impact of coverage on cosine scores\")\n", + "axs[0].set_title('Impact of coverage on cosine scores')\n", "sns.scatterplot(\n", " ax=axs[1],\n", " data=df_cas22,\n", - " x=\"coverage\",\n", - " y=\"spectral_sqrt_cosine\",\n", - " hue=\"spectral_sqrt_cosine\",\n", + " x='coverage',\n", + " y='spectral_sqrt_cosine',\n", + " hue='spectral_sqrt_cosine',\n", " hue_norm=(0, 1),\n", " palette=bluepink_grad,\n", ")\n", - "axs[1].set_ylabel(\"Cosine similarity\")\n", - "axs[1].set_xlabel(\"Peak intensity coverage\")\n", + "axs[1].set_ylabel('Cosine similarity')\n", + "axs[1].set_xlabel('Peak intensity coverage')\n", "plt.show()" ] }, @@ -4448,9 +4447,9 @@ "metadata": {}, "outputs": [], "source": [ - "df_cas22[\"library\"] = \"CASMI-22\"\n", - "df_cas22[\"RETENTIONTIME\"] = df_cas22[\"ChallengeRT\"] # \"RT_min\"\n", - "df_cas22[\"cfm_CE\"] = df_cas22[\"CE\"].apply(closest_cfm_ce)" + "df_cas22['library'] = 'CASMI-22'\n", + "df_cas22['RETENTIONTIME'] = df_cas22['ChallengeRT'] # \"RT_min\"\n", + "df_cas22['cfm_CE'] = df_cas22['CE'].apply(closest_cfm_ce)" ] }, { @@ -4460,14 +4459,15 @@ "outputs": [], "source": [ "import fiora.IO.cfmReader as cfmReader\n", + "\n", "# time CFM-ID 4: -> 12m16,571s\n", "\n", "\n", "cf22 = cfmReader.read(\n", - " f\"{home}/data/metabolites/cfm-id/casmi22_negative_predictions.txt\", as_df=True\n", + " f'{home}/data/metabolites/cfm-id/casmi22_negative_predictions.txt', as_df=True\n", ")\n", "cf22_p = cfmReader.read(\n", - " f\"{home}/data/metabolites/cfm-id/casmi22_positive_predictions.txt\", as_df=True\n", + " f'{home}/data/metabolites/cfm-id/casmi22_positive_predictions.txt', as_df=True\n", ")\n", "cf22 = pd.concat([cf22, cf22_p])" ] @@ -4487,28 +4487,28 @@ } ], "source": [ - "df_cas22[\"cfm_peaks\"] = None\n", - "df_cas22[[\"cfm_cosine\", \"cfm_sqrt_cosine\", \"cfm_refl_cosine\"]] = np.nan\n", + "df_cas22['cfm_peaks'] = None\n", + "df_cas22[['cfm_cosine', 'cfm_sqrt_cosine', 'cfm_refl_cosine']] = np.nan\n", "for i, cas in df_cas22.iterrows():\n", - " challenge = cas[\"ChallengeName\"]\n", + " challenge = cas['ChallengeName']\n", "\n", - " if len(cf22[cf22[\"#ID\"] == challenge]) != 1:\n", - " print(f\"{challenge} not found in CFM-ID results. Skipping.\")\n", + " if len(cf22[cf22['#ID'] == challenge]) != 1:\n", + " print(f'{challenge} not found in CFM-ID results. 
Skipping.')\n", " continue\n", - " cfm_data = cf22[cf22[\"#ID\"] == challenge].iloc[0]\n", - "\n", - " if cas[\"ChallengeName\"] != cfm_data[\"#ID\"]:\n", - " raise ValueError(\"Wrong challenge matched\")\n", - " cfm_peaks = cfm_data[\"peaks\" + cas[\"cfm_CE\"]] # find best reference CE\n", - " df_cas22.at[i, \"cfm_peaks\"] = cfm_peaks\n", - " df_cas22.at[i, \"cfm_cosine\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks)\n", - " df_cas22.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(\n", - " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " cfm_data = cf22[cf22['#ID'] == challenge].iloc[0]\n", + "\n", + " if cas['ChallengeName'] != cfm_data['#ID']:\n", + " raise ValueError('Wrong challenge matched')\n", + " cfm_peaks = cfm_data['peaks' + cas['cfm_CE']] # find best reference CE\n", + " df_cas22.at[i, 'cfm_peaks'] = cfm_peaks\n", + " df_cas22.at[i, 'cfm_cosine'] = spectral_cosine(cas['peaks'], cfm_peaks)\n", + " df_cas22.at[i, 'cfm_sqrt_cosine'] = spectral_cosine(\n", + " cas['peaks'], cfm_peaks, transform=np.sqrt\n", " )\n", - " df_cas22.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(\n", - " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " df_cas22.at[i, 'cfm_refl_cosine'] = spectral_reflection_cosine(\n", + " cas['peaks'], cfm_peaks, transform=np.sqrt\n", " )\n", - " df_cas22.at[i, \"cfm_steins\"] = reweighted_dot(cas[\"peaks\"], cfm_peaks)" + " df_cas22.at[i, 'cfm_steins'] = reweighted_dot(cas['peaks'], cfm_peaks)" ] }, { @@ -4540,18 +4540,18 @@ "sns.boxplot(\n", " ax=axs[0],\n", " data=df_cas22,\n", - " y=\"spectral_sqrt_cosine\",\n", - " x=\"Precursor_type\",\n", - " hue=\"NCE\",\n", + " y='spectral_sqrt_cosine',\n", + " x='Precursor_type',\n", + " hue='NCE',\n", " palette=bluepink_grad8[-3:][::-1],\n", ")\n", - "axs[0].set_title(\"Spectral cosine similarity of sqrt intensities\")\n", + "axs[0].set_title('Spectral cosine similarity of sqrt intensities')\n", "sns.boxplot(\n", " ax=axs[1],\n", " data=df_cas22,\n", - " y=\"cfm_sqrt_cosine\",\n", - " x=\"Precursor_type\",\n", - " hue=\"NCE\",\n", + " y='cfm_sqrt_cosine',\n", + " x='Precursor_type',\n", + " hue='NCE',\n", " palette=bluepink_grad8,\n", ")\n", "\n", @@ -4577,20 +4577,20 @@ "source": [ "dgnn = pd.DataFrame(\n", " {\n", - " \"cosine\": df_cas22[\"spectral_sqrt_cosine\"],\n", - " \"refl_cosine\": df_cas22[\"spectral_refl_cosine\"],\n", - " \"ion_mode\": df_cas22[\"Precursor_type\"],\n", - " \"challenge_num\": df_cas22[\"ChallengeNum\"],\n", - " \"model\": \"GNN-based fragmentation\",\n", + " 'cosine': df_cas22['spectral_sqrt_cosine'],\n", + " 'refl_cosine': df_cas22['spectral_refl_cosine'],\n", + " 'ion_mode': df_cas22['Precursor_type'],\n", + " 'challenge_num': df_cas22['ChallengeNum'],\n", + " 'model': 'GNN-based fragmentation',\n", " }\n", ")\n", "dcfm = pd.DataFrame(\n", " {\n", - " \"cosine\": df_cas22[\"cfm_sqrt_cosine\"],\n", - " \"refl_cosine\": df_cas22[\"cfm_refl_cosine\"],\n", - " \"ion_mode\": df_cas22[\"Precursor_type\"],\n", - " \"challenge_num\": df_cas22[\"ChallengeNum\"],\n", - " \"model\": \"CFM-ID 4.4.7\",\n", + " 'cosine': df_cas22['cfm_sqrt_cosine'],\n", + " 'refl_cosine': df_cas22['cfm_refl_cosine'],\n", + " 'ion_mode': df_cas22['Precursor_type'],\n", + " 'challenge_num': df_cas22['ChallengeNum'],\n", + " 'model': 'CFM-ID 4.4.7',\n", " }\n", ")\n", "\n", @@ -4603,16 +4603,16 @@ "sns.boxplot(\n", " ax=ax,\n", " data=D,\n", - " y=\"cosine\",\n", - " x=\"ion_mode\",\n", - " hue=\"model\",\n", + " y='cosine',\n", + " x='ion_mode',\n", + " hue='model',\n", " palette=bluepink[:2],\n", - " 
order=[\"[M+H]+\", \"[M-H]-\"],\n", + " order=['[M+H]+', '[M-H]-'],\n", ")\n", "# sns.boxplot(ax=axs[0], data=D, y=\"cosine\", x=\"model\", hue=\"ion_mode\", palette=bluepink[:1] * 2 + bluepink[:1] * 2)\n", - "ax.set_title(\"Spectral cosine similarity of CASMI-22 predictions\")\n", - "plt.xlabel(\"\")\n", - "plt.legend(loc=\"lower right\")\n", + "ax.set_title('Spectral cosine similarity of CASMI-22 predictions')\n", + "plt.xlabel('')\n", + "plt.legend(loc='lower right')\n", "plt.subplots_adjust(right=0.975, left=0.11) # TODO FIX error\n", "adjust_box_widths(fig, 0.95)\n", "# set_all_font_sizes(18)\n", @@ -4642,15 +4642,15 @@ "sns.boxplot(\n", " ax=ax,\n", " data=D,\n", - " y=\"refl_cosine\",\n", - " x=\"ion_mode\",\n", - " hue=\"model\",\n", + " y='refl_cosine',\n", + " x='ion_mode',\n", + " hue='model',\n", " palette=bluepink[:2],\n", - " order=[\"[M+H]+\", \"[M-H]-\"],\n", + " order=['[M+H]+', '[M-H]-'],\n", ")\n", - "ax.set_title(\"Spectral cosine similarity of CASMI-22 predictions\")\n", - "plt.xlabel(\"\")\n", - "plt.legend(loc=\"lower right\")\n", + "ax.set_title('Spectral cosine similarity of CASMI-22 predictions')\n", + "plt.xlabel('')\n", + "plt.legend(loc='lower right')\n", "plt.subplots_adjust(right=0.975, left=0.11)\n", "adjust_box_widths(fig, 0.95)\n", "plt.show()" @@ -4675,9 +4675,9 @@ } ], "source": [ - "df_cas[\"test_set\"] = \"CASMI-16\"\n", - "df_cas22[\"test_set\"] = \"CASMI-22\"\n", - "df_val[\"test_set\"] = \"Validation\"" + "df_cas['test_set'] = 'CASMI-16'\n", + "df_cas22['test_set'] = 'CASMI-22'\n", + "df_val['test_set'] = 'Validation'" ] }, { @@ -4754,20 +4754,20 @@ } ], "source": [ - "score = \"spectral_sqrt_cosine\"\n", + "score = 'spectral_sqrt_cosine'\n", "fiora_res = {\n", - " \"model\": \"Fiora\",\n", - " \"CASMI16\": np.median(df_cas[score.replace(\"spectral\", \"merged\")]),\n", - " \"CASMI22\": np.median(df_cas22[score]),\n", + " 'model': 'Fiora',\n", + " 'CASMI16': np.median(df_cas[score.replace('spectral', 'merged')]),\n", + " 'CASMI22': np.median(df_cas22[score]),\n", "}\n", "cfm_id = {\n", - " \"model\": \"CFM-ID 4.4.7\",\n", - " \"CASMI16\": np.median(df_cas[score.replace(\"spectral\", \"cfm\")].dropna()),\n", - " \"CASMI22\": np.median(df_cas22[score.replace(\"spectral\", \"cfm\")]),\n", + " 'model': 'CFM-ID 4.4.7',\n", + " 'CASMI16': np.median(df_cas[score.replace('spectral', 'cfm')].dropna()),\n", + " 'CASMI22': np.median(df_cas22[score.replace('spectral', 'cfm')]),\n", "}\n", "\n", "summary = pd.DataFrame([fiora_res, cfm_id])\n", - "print(\"Summary test sets\")\n", + "print('Summary test sets')\n", "summary" ] }, @@ -4850,56 +4850,56 @@ } ], "source": [ - "score = \"spectral_sqrt_cosine\"\n", + "score = 'spectral_sqrt_cosine'\n", "fiora_res = {\n", - " \"model\": \"Fiora\",\n", - " \"CASMI16+\": np.median(\n", - " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"merged\")\n", + " 'model': 'Fiora',\n", + " 'CASMI16+': np.median(\n", + " df_cas[df_cas['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'merged')\n", " ]\n", " ),\n", - " \"CASMI16-\": np.median(\n", - " df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][\n", - " score.replace(\"spectral\", \"merged\")\n", + " 'CASMI16-': np.median(\n", + " df_cas[df_cas['Precursor_type'] == '[M-H]-'][\n", + " score.replace('spectral', 'merged')\n", " ]\n", " ),\n", - " \"CASMI16T+\": np.median(\n", - " df_cast[df_cast[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"merged\")\n", + " 'CASMI16T+': np.median(\n", + " 
df_cast[df_cast['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'merged')\n", " ]\n", " ),\n", - " \"CASMI16T-\": np.median(\n", - " df_cast[df_cast[\"Precursor_type\"] == \"[M-H]-\"][\n", - " score.replace(\"spectral\", \"merged\")\n", + " 'CASMI16T-': np.median(\n", + " df_cast[df_cast['Precursor_type'] == '[M-H]-'][\n", + " score.replace('spectral', 'merged')\n", " ]\n", " ),\n", - " \"CASMI22+\": np.median(df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][score]),\n", - " \"CASMI22-\": np.median(df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][score]),\n", + " 'CASMI22+': np.median(df_cas22[df_cas22['Precursor_type'] == '[M+H]+'][score]),\n", + " 'CASMI22-': np.median(df_cas22[df_cas22['Precursor_type'] == '[M-H]-'][score]),\n", "}\n", "cfm_id = {\n", - " \"model\": \"CFM-ID 4.4.7\",\n", - " \"CASMI16+\": np.median(\n", - " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][score.replace(\"spectral\", \"cfm\")]\n", + " 'model': 'CFM-ID 4.4.7',\n", + " 'CASMI16+': np.median(\n", + " df_cas[df_cas['Precursor_type'] == '[M+H]+'][score.replace('spectral', 'cfm')]\n", " ),\n", - " \"CASMI16-\": np.median(\n", - " df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'CASMI16-': np.median(\n", + " df_cas[df_cas['Precursor_type'] == '[M-H]-'][\n", + " score.replace('spectral', 'cfm')\n", " ].dropna()\n", " ),\n", - " \"CASMI22+\": np.median(\n", - " df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'CASMI22+': np.median(\n", + " df_cas22[df_cas22['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'cfm')\n", " ]\n", " ),\n", - " \"CASMI22-\": np.median(\n", - " df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'CASMI22-': np.median(\n", + " df_cas22[df_cas22['Precursor_type'] == '[M-H]-'][\n", + " score.replace('spectral', 'cfm')\n", " ]\n", " ),\n", "}\n", "\n", "summaryPos = pd.DataFrame([fiora_res, cfm_id])\n", - "print(\"Summary test sets\")\n", + "print('Summary test sets')\n", "summaryPos" ] }, @@ -4930,27 +4930,27 @@ "source": [ "raise KeyboardInterrupt()\n", "\n", - "df_cas.loc[:, \"tanimoto\"] = np.nan\n", + "df_cas.loc[:, 'tanimoto'] = np.nan\n", "for i, d in df_cas.iterrows():\n", - " df_cas.at[i, \"tanimoto\"] = (\n", - " df_train[\"Metabolite\"]\n", - " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"]))\n", + " df_cas.at[i, 'tanimoto'] = (\n", + " df_train['Metabolite']\n", + " .apply(lambda x: x.tanimoto_similarity(d['Metabolite']))\n", " .max()\n", " )\n", "\n", - "df_cas22.loc[:, \"tanimoto\"] = np.nan\n", + "df_cas22.loc[:, 'tanimoto'] = np.nan\n", "for i, d in df_cas22.iterrows():\n", - " df_cas22.at[i, \"tanimoto\"] = (\n", - " df_train[\"Metabolite\"]\n", - " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"]))\n", + " df_cas22.at[i, 'tanimoto'] = (\n", + " df_train['Metabolite']\n", + " .apply(lambda x: x.tanimoto_similarity(d['Metabolite']))\n", " .max()\n", " )\n", "\n", - "df_val.loc[:, \"tanimoto\"] = np.nan\n", + "df_val.loc[:, 'tanimoto'] = np.nan\n", "for i, d in df_val.iterrows():\n", - " df_val.at[i, \"tanimoto\"] = (\n", - " df_train[\"Metabolite\"]\n", - " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"]))\n", + " df_val.at[i, 'tanimoto'] = (\n", + " df_train['Metabolite']\n", + " .apply(lambda x: x.tanimoto_similarity(d['Metabolite']))\n", " .max()\n", " )" ] @@ -4964,8 +4964,8 @@ "fig, ax = plt.subplots(1, 1, figsize=(12, 6))\n", 
"sns.boxplot(\n", " data=df_cas,\n", - " x=pd.cut(df_cas[\"tanimoto\"], bins=[x / 10.0 for x in list(range(0, 10, 1))]),\n", - " y=\"merged_sqrt_cosine\",\n", + " x=pd.cut(df_cas['tanimoto'], bins=[x / 10.0 for x in list(range(0, 10, 1))]),\n", + " y='merged_sqrt_cosine',\n", ") # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", "plt.ylim([0, 1])\n", "plt.show()\n", @@ -4974,8 +4974,8 @@ "fig, ax = plt.subplots(1, 1, figsize=(12, 6))\n", "sns.boxplot(\n", " data=df_val,\n", - " x=pd.cut(df_val[\"tanimoto\"], bins=[x / 10.0 for x in list(range(0, 10, 1))]),\n", - " y=\"spectral_sqrt_cosine\",\n", + " x=pd.cut(df_val['tanimoto'], bins=[x / 10.0 for x in list(range(0, 10, 1))]),\n", + " y='spectral_sqrt_cosine',\n", ") # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", "plt.ylim([0, 1])\n", "plt.show()" @@ -4991,11 +4991,11 @@ "C = pd.concat([df_val, df_cas, df_cas22])\n", "sns.pointplot(\n", " data=C,\n", - " x=pd.cut(C[\"tanimoto\"], bins=[x / 10.0 for x in list(range(0, 10, 1))]),\n", - " y=\"spectral_sqrt_cosine\",\n", + " x=pd.cut(C['tanimoto'], bins=[x / 10.0 for x in list(range(0, 10, 1))]),\n", + " y='spectral_sqrt_cosine',\n", " palette=tri_palette,\n", " capsize=0.2,\n", - " hue=\"test_set\",\n", + " hue='test_set',\n", " dodge=0.3,\n", ") # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", "plt.ylim([0, 1])\n", @@ -5016,38 +5016,36 @@ "metadata": {}, "outputs": [], "source": [ - "import scipy\n", - "\n", - "df_val_unique = df_val.drop_duplicates(subset=\"SMILES\", keep=\"first\")\n", - "df_cas22_unique = df_cas22.drop_duplicates(subset=\"SMILES\", keep=\"first\")\n", + "df_val_unique = df_val.drop_duplicates(subset='SMILES', keep='first')\n", + "df_cas22_unique = df_cas22.drop_duplicates(subset='SMILES', keep='first')\n", "\n", "#\n", "# Recalibration\n", "#\n", "\n", "\n", - "df_val_unique[\"RT_pred_cal\"] = df_val_unique[\"RT_pred\"]\n", + "df_val_unique['RT_pred_cal'] = df_val_unique['RT_pred']\n", "\n", "confident_ids = (\n", - " df_cas.sort_values(by=\"merged_sqrt_cosine\", ascending=False)\n", - " .dropna(subset=[\"RETENTIONTIME\"])\n", + " df_cas.sort_values(by='merged_sqrt_cosine', ascending=False)\n", + " .dropna(subset=['RETENTIONTIME'])\n", " .head(10)\n", ")\n", "rt_slope, rt_intercept, r_value, p_value, std_err = scipy.stats.linregress(\n", - " confident_ids[\"RT_pred\"].astype(float), confident_ids[\"RETENTIONTIME\"]\n", + " confident_ids['RT_pred'].astype(float), confident_ids['RETENTIONTIME']\n", ")\n", - "df_cas[\"RT_pred_cal\"] = rt_intercept + df_cas[\"RT_pred\"] * rt_slope\n", + "df_cas['RT_pred_cal'] = rt_intercept + df_cas['RT_pred'] * rt_slope\n", "\n", "\n", "confident_ids = (\n", - " df_cas22_unique.sort_values(by=\"spectral_sqrt_cosine\", ascending=False)\n", - " .dropna(subset=[\"RETENTIONTIME\"])\n", + " df_cas22_unique.sort_values(by='spectral_sqrt_cosine', ascending=False)\n", + " .dropna(subset=['RETENTIONTIME'])\n", " .head(10)\n", ")\n", "rt_slope, rt_intercept, r_value, p_value, std_err = scipy.stats.linregress(\n", - " confident_ids[\"RT_pred\"].astype(float), confident_ids[\"RETENTIONTIME\"]\n", + " confident_ids['RT_pred'].astype(float), confident_ids['RETENTIONTIME']\n", ")\n", - "df_cas22_unique[\"RT_pred_cal\"] = rt_intercept + df_cas22_unique[\"RT_pred\"] * rt_slope" + "df_cas22_unique['RT_pred_cal'] = rt_intercept + 
df_cas22_unique['RT_pred'] * rt_slope" ] }, { @@ -5058,16 +5056,16 @@ "source": [ "RT = pd.concat(\n", " [\n", - " df_val_unique[[\"RETENTIONTIME\", \"RT_pred\", \"RT_pred_cal\", \"RT_dif\", \"library\"]],\n", - " df_cas[[\"RETENTIONTIME\", \"RT_pred\", \"RT_pred_cal\", \"RT_dif\", \"library\"]],\n", + " df_val_unique[['RETENTIONTIME', 'RT_pred', 'RT_pred_cal', 'RT_dif', 'library']],\n", + " df_cas[['RETENTIONTIME', 'RT_pred', 'RT_pred_cal', 'RT_dif', 'library']],\n", " df_cas22_unique[\n", - " [\"RETENTIONTIME\", \"RT_pred\", \"RT_pred_cal\", \"RT_dif\", \"library\"]\n", + " ['RETENTIONTIME', 'RT_pred', 'RT_pred_cal', 'RT_dif', 'library']\n", " ],\n", " ],\n", " ignore_index=True,\n", ")\n", - "RT = RT.dropna(subset=[\"RETENTIONTIME\"])\n", - "RT[\"RT_rel_dif\"] = RT[\"RT_dif\"] / RT[\"RETENTIONTIME\"]" + "RT = RT.dropna(subset=['RETENTIONTIME'])\n", + "RT['RT_rel_dif'] = RT['RT_dif'] / RT['RETENTIONTIME']" ] }, { @@ -5077,32 +5075,32 @@ "outputs": [], "source": [ "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12, 8), gridspec_kw={\"width_ratios\": [5, 2]}, sharey=False\n", + " 1, 2, figsize=(12, 8), gridspec_kw={'width_ratios': [5, 2]}, sharey=False\n", ")\n", "plt.subplots_adjust(wspace=0.1)\n", "\n", "sns.scatterplot(\n", " ax=axs[0],\n", " data=RT,\n", - " x=\"RETENTIONTIME\",\n", - " y=\"RT_pred_cal\",\n", - " hue=\"library\",\n", + " x='RETENTIONTIME',\n", + " y='RT_pred_cal',\n", + " hue='library',\n", " palette=tri_palette,\n", - " style=\"library\",\n", - " color=\"gray\",\n", + " style='library',\n", + " color='gray',\n", ")\n", "axs[0].set_ylim([0, 30])\n", "axs[0].set_xlim([0, 30])\n", - "axs[0].set_ylabel(\"Predicted retention time\")\n", - "axs[0].set_xlabel(\"Observed retention time\")\n", - "sns.lineplot(ax=axs[0], x=[0, 100], y=[0, 100], color=\"black\")\n", + "axs[0].set_ylabel('Predicted retention time')\n", + "axs[0].set_xlabel('Observed retention time')\n", + "sns.lineplot(ax=axs[0], x=[0, 100], y=[0, 100], color='black')\n", "\n", "sns.boxplot(\n", - " ax=axs[1], data=RT, y=\"RT_dif\", palette=tri_palette, x=\"library\", showfliers=False\n", + " ax=axs[1], data=RT, y='RT_dif', palette=tri_palette, x='library', showfliers=False\n", ")\n", - "axs[1].set_xlabel(\"\")\n", - "axs[1].set_ylabel(\"RT difference (in minutes)\")\n", - "axs[1].yaxis.set_label_position(\"right\")\n", + "axs[1].set_xlabel('')\n", + "axs[1].set_ylabel('RT difference (in minutes)')\n", + "axs[1].yaxis.set_label_position('right')\n", "axs[1].yaxis.tick_right()\n", "plt.show()" ] @@ -5114,17 +5112,17 @@ "outputs": [], "source": [ "print(\n", - " np.corrcoef(RT[\"RETENTIONTIME\"].values, RT[\"RT_pred_cal\"].values, dtype=float)[0, 1]\n", + " np.corrcoef(RT['RETENTIONTIME'].values, RT['RT_pred_cal'].values, dtype=float)[0, 1]\n", ")\n", "print(\n", " np.corrcoef(\n", - " df_val_unique.dropna(subset=[\"RETENTIONTIME\"])[\"RETENTIONTIME\"],\n", - " df_val_unique.dropna(subset=[\"RETENTIONTIME\"])[\"RT_pred\"],\n", + " df_val_unique.dropna(subset=['RETENTIONTIME'])['RETENTIONTIME'],\n", + " df_val_unique.dropna(subset=['RETENTIONTIME'])['RT_pred'],\n", " dtype=float,\n", " )[0, 1],\n", - " np.corrcoef(df_cas[\"RETENTIONTIME\"], df_cas[\"RT_pred\"], dtype=float)[0, 1],\n", + " np.corrcoef(df_cas['RETENTIONTIME'], df_cas['RT_pred'], dtype=float)[0, 1],\n", " np.corrcoef(\n", - " df_cas22_unique[\"RETENTIONTIME\"], df_cas22_unique[\"RT_pred\"], dtype=float\n", + " df_cas22_unique['RETENTIONTIME'], df_cas22_unique['RT_pred'], dtype=float\n", " )[0, 1],\n", ")" ] @@ -5136,7 +5134,7 @@ "outputs": [], 
"source": [ "fig, axs = plt.subplots(\n", - " 2, 1, figsize=(12, 14), gridspec_kw={\"height_ratios\": [1, 5]}, sharex=True\n", + " 2, 1, figsize=(12, 14), gridspec_kw={'height_ratios': [1, 5]}, sharex=True\n", ")\n", "plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", "\n", @@ -5145,31 +5143,31 @@ "sns.kdeplot(\n", " ax=axs[0],\n", " data=RT,\n", - " x=\"RETENTIONTIME\",\n", + " x='RETENTIONTIME',\n", " bw_adjust=0.25,\n", - " color=\"black\",\n", - " multiple=\"layer\",\n", - " hue=\"library\",\n", + " color='black',\n", + " multiple='layer',\n", + " hue='library',\n", " palette=tri_palette,\n", ") # hue=\"Precursor_type\",\n", - "axs[0].spines[\"top\"].set_visible(False)\n", - "axs[0].spines[\"right\"].set_visible(False)\n", + "axs[0].spines['top'].set_visible(False)\n", + "axs[0].spines['right'].set_visible(False)\n", "\n", "\n", "sns.scatterplot(\n", - " ax=axs[1], data=df_val_unique, x=\"RETENTIONTIME\", y=\"RT_pred_cal\", color=\"gray\"\n", + " ax=axs[1], data=df_val_unique, x='RETENTIONTIME', y='RT_pred_cal', color='gray'\n", ") # , hue=\"library\", palette=tri_palette, style=\"library\", color=\"gray\")\n", - "axs[1].set_ylim([0, df_val_unique[\"RETENTIONTIME\"].max() + 1])\n", - "axs[1].set_xlim([0, df_val_unique[\"RETENTIONTIME\"].max() + 1])\n", - "axs[1].set_ylabel(\"Predicted retention time\")\n", - "axs[1].set_xlabel(\"Observed retention time\")\n", + "axs[1].set_ylim([0, df_val_unique['RETENTIONTIME'].max() + 1])\n", + "axs[1].set_xlim([0, df_val_unique['RETENTIONTIME'].max() + 1])\n", + "axs[1].set_ylabel('Predicted retention time')\n", + "axs[1].set_xlabel('Observed retention time')\n", "line = [0, 100]\n", - "sns.lineplot(ax=axs[1], x=line, y=line, color=\"black\")\n", + "sns.lineplot(ax=axs[1], x=line, y=line, color='black')\n", "sns.lineplot(\n", - " ax=axs[1], x=line, y=[x + 20 / 60.0 for x in line], color=\"black\", linestyle=\"--\"\n", + " ax=axs[1], x=line, y=[x + 20 / 60.0 for x in line], color='black', linestyle='--'\n", ")\n", "sns.lineplot(\n", - " ax=axs[1], x=line, y=[x - 20 / 60.0 for x in line], color=\"black\", linestyle=\"--\"\n", + " ax=axs[1], x=line, y=[x - 20 / 60.0 for x in line], color='black', linestyle='--'\n", ")" ] }, @@ -5194,7 +5192,7 @@ "outputs": [], "source": [ "fig, axs = plt.subplots(\n", - " 2, 1, figsize=(12, 14), gridspec_kw={\"height_ratios\": [1, 5]}, sharex=True\n", + " 2, 1, figsize=(12, 14), gridspec_kw={'height_ratios': [1, 5]}, sharex=True\n", ")\n", "plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", "\n", @@ -5203,9 +5201,9 @@ "\n", "CCS = pd.concat(\n", " [\n", - " df_val_unique[[\"CCS\", \"CCS_pred\", \"library\"]],\n", - " df_cas[[\"CCS\", \"CCS_pred\", \"library\"]],\n", - " df_cas22_unique[[\"CCS\", \"CCS_pred\", \"library\"]],\n", + " df_val_unique[['CCS', 'CCS_pred', 'library']],\n", + " df_cas[['CCS', 'CCS_pred', 'library']],\n", + " df_cas22_unique[['CCS', 'CCS_pred', 'library']],\n", " ],\n", " ignore_index=True,\n", ")\n", @@ -5213,34 +5211,34 @@ "\n", "# sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", "sns.kdeplot(\n", - " ax=axs[0], data=CCS, x=\"CCS\", bw_adjust=0.25, color=\"black\", multiple=\"stack\"\n", + " ax=axs[0], data=CCS, x='CCS', bw_adjust=0.25, color='black', multiple='stack'\n", ") # hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", - "axs[0].spines[\"top\"].set_visible(False)\n", - 
"axs[0].spines[\"right\"].set_visible(False)\n", + "axs[0].spines['top'].set_visible(False)\n", + "axs[0].spines['right'].set_visible(False)\n", "\n", "\n", "sns.scatterplot(\n", " ax=axs[1],\n", " data=CCS,\n", - " x=\"CCS\",\n", - " y=\"CCS_pred\",\n", + " x='CCS',\n", + " y='CCS_pred',\n", " s=25,\n", - " hue=\"library\",\n", + " hue='library',\n", " palette=tri_palette,\n", - " style=\"library\",\n", - " color=\"gray\",\n", + " style='library',\n", + " color='gray',\n", ") # , color=\"blue\", edgecolor=\"blue\")#,\n", - "axs[1].set_ylim([df_val_unique[\"CCS\"].min() - 10, df_val_unique[\"CCS\"].max() + 10])\n", - "axs[1].set_xlim([df_val_unique[\"CCS\"].min() - 10, df_val_unique[\"CCS\"].max() + 10])\n", - "axs[1].set_ylabel(\"Predicted CCS\")\n", - "axs[1].set_xlabel(\"Observed CCS\")\n", - "line = [df_val_unique[\"CCS\"].min() - 10, df_val_unique[\"CCS\"].max() + 10]\n", - "sns.lineplot(ax=axs[1], x=line, y=line, color=\"black\")\n", + "axs[1].set_ylim([df_val_unique['CCS'].min() - 10, df_val_unique['CCS'].max() + 10])\n", + "axs[1].set_xlim([df_val_unique['CCS'].min() - 10, df_val_unique['CCS'].max() + 10])\n", + "axs[1].set_ylabel('Predicted CCS')\n", + "axs[1].set_xlabel('Observed CCS')\n", + "line = [df_val_unique['CCS'].min() - 10, df_val_unique['CCS'].max() + 10]\n", + "sns.lineplot(ax=axs[1], x=line, y=line, color='black')\n", "sns.lineplot(\n", - " ax=axs[1], x=line, y=[1.1 * x for x in line], color=\"black\", linestyle=\"--\"\n", + " ax=axs[1], x=line, y=[1.1 * x for x in line], color='black', linestyle='--'\n", ")\n", "sns.lineplot(\n", - " ax=axs[1], x=line, y=[0.9 * x for x in line], color=\"black\", linestyle=\"--\"\n", + " ax=axs[1], x=line, y=[0.9 * x for x in line], color='black', linestyle='--'\n", ")\n", "plt.show()" ] @@ -5270,7 +5268,7 @@ "outputs": [], "source": [ "fig, axs = plt.subplots(\n", - " 2, 1, figsize=(12, 14), gridspec_kw={\"height_ratios\": [1, 5]}, sharex=True\n", + " 2, 1, figsize=(12, 14), gridspec_kw={'height_ratios': [1, 5]}, sharex=True\n", ")\n", "plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", "\n", @@ -5279,9 +5277,9 @@ "\n", "CCS = pd.concat(\n", " [\n", - " df_val_unique[[\"CCS\", \"CCS_pred\", \"library\"]],\n", - " df_cas[[\"CCS\", \"CCS_pred\", \"library\"]],\n", - " df_cas22_unique[[\"CCS\", \"CCS_pred\", \"library\"]],\n", + " df_val_unique[['CCS', 'CCS_pred', 'library']],\n", + " df_cas[['CCS', 'CCS_pred', 'library']],\n", + " df_cas22_unique[['CCS', 'CCS_pred', 'library']],\n", " ],\n", " ignore_index=True,\n", ")\n", @@ -5289,34 +5287,34 @@ "\n", "# sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", "sns.kdeplot(\n", - " ax=axs[0], data=CCS, x=\"CCS\", bw_adjust=0.25, color=\"black\", multiple=\"stack\"\n", + " ax=axs[0], data=CCS, x='CCS', bw_adjust=0.25, color='black', multiple='stack'\n", ") # hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", - "axs[0].spines[\"top\"].set_visible(False)\n", - "axs[0].spines[\"right\"].set_visible(False)\n", + "axs[0].spines['top'].set_visible(False)\n", + "axs[0].spines['right'].set_visible(False)\n", "\n", "\n", "sns.scatterplot(\n", " ax=axs[1],\n", " data=CCS,\n", - " x=\"CCS\",\n", - " y=\"CCS_pred\",\n", + " x='CCS',\n", + " y='CCS_pred',\n", " s=25,\n", - " hue=\"library\",\n", + " hue='library',\n", " palette=tri_palette,\n", - " style=\"library\",\n", - " color=\"gray\",\n", + " 
style='library',\n", + " color='gray',\n", ") # , color=\"blue\", edgecolor=\"blue\")#,\n", - "axs[1].set_ylim([df_val_unique[\"CCS\"].min() - 10, df_val_unique[\"CCS\"].max() + 10])\n", - "axs[1].set_xlim([df_val_unique[\"CCS\"].min() - 10, df_val_unique[\"CCS\"].max() + 10])\n", - "axs[1].set_ylabel(\"Predicted CCS\")\n", - "axs[1].set_xlabel(\"Observed CCS\")\n", - "line = [df_val_unique[\"CCS\"].min() - 10, df_val_unique[\"CCS\"].max() + 10]\n", - "sns.lineplot(ax=axs[1], x=line, y=line, color=\"black\")\n", + "axs[1].set_ylim([df_val_unique['CCS'].min() - 10, df_val_unique['CCS'].max() + 10])\n", + "axs[1].set_xlim([df_val_unique['CCS'].min() - 10, df_val_unique['CCS'].max() + 10])\n", + "axs[1].set_ylabel('Predicted CCS')\n", + "axs[1].set_xlabel('Observed CCS')\n", + "line = [df_val_unique['CCS'].min() - 10, df_val_unique['CCS'].max() + 10]\n", + "sns.lineplot(ax=axs[1], x=line, y=line, color='black')\n", "sns.lineplot(\n", - " ax=axs[1], x=line, y=[1.1 * x for x in line], color=\"black\", linestyle=\"--\"\n", + " ax=axs[1], x=line, y=[1.1 * x for x in line], color='black', linestyle='--'\n", ")\n", "sns.lineplot(\n", - " ax=axs[1], x=line, y=[0.9 * x for x in line], color=\"black\", linestyle=\"--\"\n", + " ax=axs[1], x=line, y=[0.9 * x for x in line], color='black', linestyle='--'\n", ")\n", "plt.show()" ] @@ -5327,42 +5325,42 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"Pearson Corr Coef:\")\n", + "print('Pearson Corr Coef:')\n", "print(\n", - " \"GNN\",\n", + " 'GNN',\n", " np.corrcoef(\n", - " df_val_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " df_val_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " df_val_unique.dropna(subset=['CCS'])['CCS'],\n", + " df_val_unique.dropna(subset=['CCS'])['CCS_pred'].dropna(),\n", " dtype=float,\n", " )[0, 1],\n", ")\n", "print(\n", - " \"pMZ\",\n", + " 'pMZ',\n", " np.corrcoef(\n", - " df_val_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " df_val_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna(),\n", + " df_val_unique.dropna(subset=['CCS'])['CCS'],\n", + " df_val_unique.dropna(subset=['CCS'])['PrecursorMZ'].dropna(),\n", " dtype=float,\n", " )[0, 1],\n", ")\n", "\n", "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(\n", - " df_train.dropna(subset=[\"CCS\"])[\"PRECURSORMZ\"],\n", - " df_train.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_train.dropna(subset=['CCS'])['PRECURSORMZ'],\n", + " df_train.dropna(subset=['CCS'])['CCS'],\n", ")\n", - "print(\"R2\")\n", + "print('R2')\n", "print(\n", - " \"GNN\",\n", + " 'GNN',\n", " r2_score(\n", - " df_val_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " df_val_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " df_val_unique.dropna(subset=['CCS'])['CCS'],\n", + " df_val_unique.dropna(subset=['CCS'])['CCS_pred'].dropna(),\n", " ),\n", ")\n", "print(\n", - " \"pMZ\",\n", + " 'pMZ',\n", " r2_score(\n", - " df_val_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_val_unique.dropna(subset=['CCS'])['CCS'],\n", " intercept\n", - " + slope * df_val_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna(),\n", + " + slope * df_val_unique.dropna(subset=['CCS'])['PrecursorMZ'].dropna(),\n", " ),\n", ")" ] @@ -5373,41 +5371,41 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"Pearson Corr Coef:\")\n", + "print('Pearson Corr Coef:')\n", "print(\n", - " \"GNN\",\n", + " 'GNN',\n", " np.corrcoef(\n", - " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " df_cas.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " 
df_cas.dropna(subset=['CCS'])['CCS'],\n", + " df_cas.dropna(subset=['CCS'])['CCS_pred'].dropna(),\n", " dtype=float,\n", " )[0, 1],\n", ")\n", "print(\n", - " \"pMZ\",\n", + " 'pMZ',\n", " np.corrcoef(\n", - " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " df_cas.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna(),\n", + " df_cas.dropna(subset=['CCS'])['CCS'],\n", + " df_cas.dropna(subset=['CCS'])['PRECURSOR_MZ'].dropna(),\n", " dtype=float,\n", " )[0, 1],\n", ")\n", "\n", "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(\n", - " df_train.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"],\n", - " df_train.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_train.dropna(subset=['CCS'])['PrecursorMZ'],\n", + " df_train.dropna(subset=['CCS'])['CCS'],\n", ")\n", - "print(\"R2\")\n", + "print('R2')\n", "print(\n", - " \"GNN\",\n", + " 'GNN',\n", " r2_score(\n", - " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " df_cas.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " df_cas.dropna(subset=['CCS'])['CCS'],\n", + " df_cas.dropna(subset=['CCS'])['CCS_pred'].dropna(),\n", " ),\n", ")\n", "print(\n", - " \"pMZ\",\n", + " 'pMZ',\n", " r2_score(\n", - " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " intercept + slope * df_cas.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna(),\n", + " df_cas.dropna(subset=['CCS'])['CCS'],\n", + " intercept + slope * df_cas.dropna(subset=['CCS'])['PRECURSOR_MZ'].dropna(),\n", " ),\n", ")" ] @@ -5482,53 +5480,53 @@ "metadata": {}, "outputs": [], "source": [ - "def double_mirrorplot22(i, model_title=\"Fiora\"):\n", + "def double_mirrorplot22(i, model_title='Fiora'):\n", " fig, axs = plt.subplots(\n", - " 1, 3, figsize=(16.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3, 3]}, sharey=False\n", + " 1, 3, figsize=(16.8, 4.2), gridspec_kw={'width_ratios': [1, 3, 3]}, sharey=False\n", " )\n", "\n", " plt.subplots_adjust(right=0.975, left=0.025)\n", "\n", - " img = df_cas22.iloc[i][\"Metabolite\"].draw(ax=axs[0])\n", + " img = df_cas22.iloc[i]['Metabolite'].draw(ax=axs[0])\n", "\n", " axs[0].grid(False)\n", " axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", " )\n", - " axs[0].set_title(df_cas22.iloc[i][\"ChallengeName\"])\n", + " axs[0].set_title(df_cas22.iloc[i]['ChallengeName'])\n", " axs[0].imshow(img)\n", - " axs[0].axis(\"off\")\n", + " axs[0].axis('off')\n", "\n", " sv.plot_spectrum(\n", - " {\"peaks\": df_cas22.iloc[i][\"peaks\"]},\n", - " {\"peaks\": df_cas22.iloc[i][\"sim_peaks\"]},\n", + " {'peaks': df_cas22.iloc[i]['peaks']},\n", + " {'peaks': df_cas22.iloc[i]['sim_peaks']},\n", " ax=axs[1],\n", " )\n", " axs[1].title.set_text(model_title)\n", " patch1 = mpatches.Patch(\n", - " color=\"limegreen\"\n", - " if df_cas22.iloc[i][\"cfm_sqrt_cosine\"]\n", - " < df_cas22.iloc[i][\"spectral_sqrt_cosine\"]\n", - " else \"orangered\",\n", - " label=f\"cosine {df_cas22.iloc[i]['spectral_sqrt_cosine']:.02f}\",\n", + " color='limegreen'\n", + " if df_cas22.iloc[i]['cfm_sqrt_cosine']\n", + " < df_cas22.iloc[i]['spectral_sqrt_cosine']\n", + " else 'orangered',\n", + " label=f'cosine {df_cas22.iloc[i][\"spectral_sqrt_cosine\"]:.02f}',\n", " )\n", " axs[1].legend(handles=[patch1])\n", "\n", " sv.plot_spectrum(\n", - " {\"peaks\": df_cas22.iloc[i][\"peaks\"]},\n", - " {\"peaks\": df_cas22.iloc[i][\"cfm_peaks\"]}\n", - " if df_cas22.iloc[i][\"cfm_peaks\"]\n", - " else {\"peaks\": {\"mz\": [0], \"intensity\": [0]}},\n", + " 
{'peaks': df_cas22.iloc[i]['peaks']},\n", + " {'peaks': df_cas22.iloc[i]['cfm_peaks']}\n", + " if df_cas22.iloc[i]['cfm_peaks']\n", + " else {'peaks': {'mz': [0], 'intensity': [0]}},\n", " ax=axs[2],\n", " )\n", - " axs[2].title.set_text(f\"CFM-ID 4.4.7\")\n", + " axs[2].title.set_text('CFM-ID 4.4.7')\n", "\n", " patch2 = mpatches.Patch(\n", - " color=\"limegreen\"\n", - " if df_cas22.iloc[i][\"cfm_sqrt_cosine\"]\n", - " > df_cas22.iloc[i][\"spectral_sqrt_cosine\"]\n", - " else \"orangered\",\n", - " label=f\"cosine {df_cas22.iloc[i]['cfm_sqrt_cosine']:.02f}\",\n", + " color='limegreen'\n", + " if df_cas22.iloc[i]['cfm_sqrt_cosine']\n", + " > df_cas22.iloc[i]['spectral_sqrt_cosine']\n", + " else 'orangered',\n", + " label=f'cosine {df_cas22.iloc[i][\"cfm_sqrt_cosine\"]:.02f}',\n", " )\n", " axs[2].legend(handles=[patch2])\n", "\n", @@ -5568,7 +5566,7 @@ "metadata": {}, "outputs": [], "source": [ - "df_val[\"pp\"] = df_val[\"Metabolite\"].apply(\n", + "df_val['pp'] = df_val['Metabolite'].apply(\n", " lambda x: (x.precursor_count / (sum(x.compiled_countsALL) / 2.0)).tolist()\n", ")" ] @@ -5581,11 +5579,11 @@ "source": [ "sns.histplot(\n", " df_val,\n", - " x=\"pp\",\n", - " hue=\"PRECURSORTYPE\",\n", - " multiple=\"dodge\",\n", + " x='pp',\n", + " hue='PRECURSORTYPE',\n", + " multiple='dodge',\n", " common_norm=False,\n", - " stat=\"density\",\n", + " stat='density',\n", ")" ] }, @@ -5596,7 +5594,7 @@ "outputs": [], "source": [ "sns.kdeplot(\n", - " data=df_val, x=\"pp\", y=\"CE\", hue=\"Precursor_type\"\n", + " data=df_val, x='pp', y='CE', hue='Precursor_type'\n", ") # multiple=\"dodge\", common_norm=False, stat=\"density\")" ] }, @@ -5640,7 +5638,7 @@ } ], "source": [ - "df[\"dataset\"].value_counts()" + "df['dataset'].value_counts()" ] }, { diff --git a/notebooks/grid_search.ipynb b/notebooks/grid_search.ipynb index d4e58c6..337d050 100644 --- a/notebooks/grid_search.ipynb +++ b/notebooks/grid_search.ipynb @@ -24,6 +24,7 @@ ], "source": [ "import sys\n", + "\n", "import torch\n", "\n", "seed = 42\n", @@ -32,28 +33,29 @@ "torch.set_printoptions(precision=2, sci_mode=False)\n", "\n", "\n", - "import pandas as pd\n", - "import numpy as np\n", "import ast\n", "import copy\n", "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", "# Load Modules\n", - "sys.path.append(\"..\")\n", + "sys.path.append('..')\n", "from os.path import expanduser\n", "\n", - "home = expanduser(\"~\")\n", - "from fiora.MOL.constants import DEFAULT_PPM, PPM, DEFAULT_MODES\n", - "from fiora.IO.LibraryLoader import LibraryLoader\n", - "from fiora.MOL.FragmentationTree import FragmentationTree\n", - "import fiora.visualization.spectrum_visualizer as sv\n", - "\n", - "from sklearn.metrics import r2_score\n", + "home = expanduser('~')\n", "import scipy\n", "from rdkit import RDLogger\n", + "from sklearn.metrics import r2_score\n", + "\n", + "import fiora.visualization.spectrum_visualizer as sv\n", + "from fiora.IO.LibraryLoader import LibraryLoader\n", + "from fiora.MOL.constants import DEFAULT_MODES, DEFAULT_PPM, PPM\n", + "from fiora.MOL.FragmentationTree import FragmentationTree\n", "\n", - "RDLogger.DisableLog(\"rdApp.*\")\n", + "RDLogger.DisableLog('rdApp.*')\n", "\n", - "print(f\"Working with Python {sys.version}\")" + "print(f'Working with Python {sys.version}')" ] }, { @@ -80,13 +82,13 @@ "source": [ "from typing import Literal\n", "\n", - "lib: Literal[\"NIST\", \"MSDIAL\", \"NIST/MSDIAL\"] = \"NIST/MSDIAL\"\n", - "print(f\"Preparing {lib} library\")\n", + "lib: Literal['NIST', 'MSDIAL', 'NIST/MSDIAL'] 
= 'NIST/MSDIAL'\n", + "print(f'Preparing {lib} library')\n", "\n", "test_run = False # Default: False\n", "if test_run:\n", " print(\n", - " \"+++ This is a test run with a small subset of data points. Results are not representative. +++\"\n", + " '+++ This is a test run with a small subset of data points. Results are not representative. +++'\n", " )" ] }, @@ -98,14 +100,14 @@ "source": [ "# key map to read metadata from pandas DataFrame\n", "metadata_key_map = {\n", - " \"name\": \"Name\",\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"ionization\": \"Ionization\",\n", - " \"precursor_mz\": \"PrecursorMZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", - " \"ccs\": \"CCS\",\n", + " 'name': 'Name',\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'ionization': 'Ionization',\n", + " 'precursor_mz': 'PrecursorMZ',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'RETENTIONTIME',\n", + " 'ccs': 'CCS',\n", "}\n", "\n", "\n", @@ -116,19 +118,19 @@ "\n", "def load_training_data():\n", " L = LibraryLoader()\n", - " df = L.load_from_csv(f\"{home}/data/metabolites/preprocessed/datasplits_Jan24.csv\")\n", + " df = L.load_from_csv(f'{home}/data/metabolites/preprocessed/datasplits_Jan24.csv')\n", " return df\n", "\n", "\n", "df = load_training_data()\n", "\n", "# Restore dictionary values\n", - "dict_columns = [\"peaks\", \"summary\"]\n", + "dict_columns = ['peaks', 'summary']\n", "for col in dict_columns:\n", - " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace(\"nan\", \"None\")))\n", + " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", " # df[col] = df[col].apply(ast.literal_eval)\n", "\n", - "df[\"group_id\"] = df[\"group_id\"].astype(int)" + "df['group_id'] = df['group_id'].astype(int)" ] }, { @@ -141,7 +143,7 @@ "# TODO: ATTENTION\n", "#\n", "\n", - "df[\"ppm_peak_tolerance\"] = 100 * PPM" + "df['ppm_peak_tolerance'] = 100 * PPM" ] }, { @@ -151,11 +153,11 @@ "outputs": [], "source": [ "%%capture\n", - "from fiora.MOL.Metabolite import Metabolite\n", - "from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", - "from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder\n", "from fiora.GNN.SetupFeatureEncoder import SetupFeatureEncoder\n", "\n", + "from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", + "from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder\n", + "from fiora.MOL.Metabolite import Metabolite\n", "\n", "CE_upper_limit = 100.0\n", "weight_upper_limit = 1000.0\n", @@ -166,38 +168,38 @@ " # df = df.iloc[5000:20000,:]\n", "\n", "\n", - "df[\"Metabolite\"] = df[\"SMILES\"].apply(Metabolite)\n", - "df[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df['Metabolite'] = df['SMILES'].apply(Metabolite)\n", + "df['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", - "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", + "node_encoder = AtomFeatureEncoder(feature_list=['symbol', 'num_hydrogen', 'ring_type'])\n", + "bond_encoder = BondFeatureEncoder(feature_list=['bond_type', 'ring_type'])\n", "setup_encoder = SetupFeatureEncoder(\n", " feature_list=[\n", - " \"collision_energy\",\n", - " \"molecular_weight\",\n", - " \"precursor_mode\",\n", - " \"instrument\",\n", + " 'collision_energy',\n", 
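The ppm_peak_tolerance column above fixes a 100 ppm matching window via the PPM constant. As a short illustration, a relative ppm tolerance maps to an absolute m/z window as follows, assuming PPM == 1e-6 (the conventional parts-per-million factor; fiora's own constant is defined outside this patch):

PPM = 1e-6  # assumption: conventional parts-per-million scale factor

def mz_window(mz: float, tol_ppm: float = 100.0) -> tuple[float, float]:
    # A peak matches a fragment m/z if it falls within +/- mz * tol_ppm * PPM.
    delta = mz * tol_ppm * PPM
    return mz - delta, mz + delta

print(mz_window(500.0))  # (499.95, 500.05): +/- 0.05 Da at m/z 500 for 100 ppm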
+ " 'molecular_weight',\n", + " 'precursor_mode',\n", + " 'instrument',\n", " ]\n", ")\n", "rt_encoder = SetupFeatureEncoder(\n", - " feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"]\n", + " feature_list=['molecular_weight', 'precursor_mode', 'instrument']\n", ")\n", "\n", - "setup_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", - "setup_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", - "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", + "setup_encoder.normalize_features['collision_energy']['max'] = CE_upper_limit\n", + "setup_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", + "rt_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", "\n", - "df[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]), axis=1)\n", + "df['Metabolite'].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", + "df.apply(lambda x: x['Metabolite'].set_id(x['group_id']), axis=1)\n", "\n", "# df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", "df.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", ")\n", - "df[\"num_per_group\"] = df[\"group_id\"].map(df[\"group_id\"].value_counts())\n", - "df[\"loss_weight\"] = 1.0 / df[\"num_per_group\"]\n", - "df.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)" + "df['num_per_group'] = df['group_id'].map(df['group_id'].value_counts())\n", + "df['loss_weight'] = 1.0 / df['num_per_group']\n", + "df.apply(lambda x: x['Metabolite'].set_loss_weight(x['loss_weight']), axis=1)" ] }, { @@ -207,10 +209,10 @@ "outputs": [], "source": [ "%%capture\n", - "df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "df.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=x['ppm_peak_tolerance']\n", " ),\n", " axis=1,\n", ")" @@ -238,12 +240,12 @@ "#\n", "\n", "\n", - "df[\"coverage\"] = df[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", + "df['coverage'] = df['Metabolite'].apply(lambda x: x.match_stats['coverage'])\n", "\n", - "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", "\n", - "sns.violinplot(df, y=\"coverage\", hue=\"lib\", split=True)\n", + "sns.violinplot(df, y='coverage', hue='lib', split=True)\n", "plt.show()" ] }, @@ -276,7 +278,7 @@ "metadata": {}, "outputs": [], "source": [ - "df = df[df[\"coverage\"] > 0.5] # TODO: ATTENTION 50 ppm + 50% cov cutoff" + "df = df[df['coverage'] > 0.5] # TODO: ATTENTION 50 ppm + 50% cov cutoff" ] }, { @@ -315,18 +317,18 @@ "metadata": {}, "outputs": [], "source": [ - "casmi16_path = f\"{home}/data/metabolites/CASMI_2016/casmi16_withCCS.csv\"\n", - "casmi22_path = f\"{home}/data/metabolites/CASMI_2022/casmi22_withCCS.csv\"\n", + "casmi16_path = f'{home}/data/metabolites/CASMI_2016/casmi16_withCCS.csv'\n", + "casmi22_path = 
f'{home}/data/metabolites/CASMI_2022/casmi22_withCCS.csv'\n", "\n", "df_cas = pd.read_csv(casmi16_path, index_col=[0], low_memory=False)\n", "df_cas22 = pd.read_csv(casmi22_path, index_col=[0], low_memory=False)\n", "\n", "# Restore dictionary values\n", - "dict_columns = [\"peaks\", \"Candidates\"]\n", + "dict_columns = ['peaks', 'Candidates']\n", "for col in dict_columns:\n", " df_cas[col] = df_cas[col].apply(ast.literal_eval)\n", "\n", - "df_cas22[\"peaks\"] = df_cas22[\"peaks\"].apply(ast.literal_eval)" + "df_cas22['peaks'] = df_cas22['peaks'].apply(ast.literal_eval)" ] }, { @@ -338,36 +340,36 @@ "%%capture\n", "from fiora.MOL.collision_energy import NCE_to_eV\n", "\n", - "df_cas[\"RETENTIONTIME\"] = df_cas[\"RTINSECONDS\"] / 60.0\n", - "df_cas[\"Metabolite\"] = df_cas[\"SMILES\"].apply(Metabolite)\n", - "df_cas[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df_cas['RETENTIONTIME'] = df_cas['RTINSECONDS'] / 60.0\n", + "df_cas['Metabolite'] = df_cas['SMILES'].apply(Metabolite)\n", + "df_cas['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cas[\"Metabolite\"].apply(\n", + "df_cas['Metabolite'].apply(\n", " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", ")\n", - "df_cas[\"CE\"] = 20.0 # actually stepped 20/35/50\n", - "df_cas[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", + "df_cas['CE'] = 20.0 # actually stepped 20/35/50\n", + "df_cas['Instrument_type'] = 'HCD' # CHECK if correct Orbitrap\n", "\n", "metadata_key_map16 = {\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"PRECURSOR_MZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'precursor_mz': 'PRECURSOR_MZ',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'RETENTIONTIME',\n", "}\n", "\n", - "df_cas[\"summary\"] = df_cas.apply(\n", + "df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", ")\n", "df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder), axis=1\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder), axis=1\n", ")\n", "\n", "# Fragmentation\n", - "df_cas[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df_cas['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=100 * PPM\n", " ),\n", " axis=1,\n", ") # Optional: use mz_cut instead\n", @@ -376,37 +378,37 @@ "# CASMI 22\n", "#\n", "\n", - "df_cas22[\"Metabolite\"] = df_cas22[\"SMILES\"].apply(Metabolite)\n", - "df_cas22[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df_cas22['Metabolite'] = df_cas22['SMILES'].apply(Metabolite)\n", + "df_cas22['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cas22[\"Metabolite\"].apply(\n", + "df_cas22['Metabolite'].apply(\n", " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", ")\n", - "df_cas22[\"CE\"] = df_cas22.apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"precursor_mz\"]), axis=1\n", + "df_cas22['CE'] = 
df_cas22.apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['precursor_mz']), axis=1\n", ")\n", "\n", "metadata_key_map22 = {\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"precursor_mz\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"ChallengeRT\",\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'precursor_mz': 'precursor_mz',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'ChallengeRT',\n", "}\n", "\n", - "df_cas22[\"summary\"] = df_cas22.apply(\n", + "df_cas22['summary'] = df_cas22.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map22.items()}, axis=1\n", ")\n", "df_cas22.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", ")\n", "\n", "# Fragmentation\n", - "df_cas22[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df_cas22['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "df_cas22.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=100 * PPM\n", " ),\n", " axis=1,\n", ") # Optional: use mz_cut instead\n", @@ -436,16 +438,17 @@ } ], "source": [ - "from fiora.GNN.Trainer import Trainer\n", "import torch_geometric as geom\n", "\n", + "from fiora.GNN.Trainer import Trainer\n", + "\n", "if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", - " dev = \"cuda:1\"\n", + " dev = 'cuda:1'\n", "else:\n", - " dev = \"cpu\"\n", + " dev = 'cpu'\n", "\n", - "print(f\"Running on device: {dev}\")" + "print(f'Running on device: {dev}')" ] }, { @@ -474,10 +477,10 @@ } ], "source": [ - "print(df.groupby(\"dataset\")[\"group_id\"].unique().apply(len))\n", + "print(df.groupby('dataset')['group_id'].unique().apply(len))\n", "\n", - "df_test = df[df[\"dataset\"] == \"test\"]\n", - "df_train = df[df[\"dataset\"].isin([\"training\", \"validation\"])]" + "df_test = df[df['dataset'] == 'test']\n", + "df_train = df[df['dataset'].isin(['training', 'validation'])]" ] }, { @@ -494,8 +497,8 @@ } ], "source": [ - "geo_data = df_train[\"Metabolite\"].apply(lambda x: x.as_geometric_data().to(dev)).values\n", - "print(f\"Prepared training/validation with {len(geo_data)} data points\")" + "geo_data = df_train['Metabolite'].apply(lambda x: x.as_geometric_data().to(dev)).values\n", + "print(f'Prepared training/validation with {len(geo_data)} data points')" ] }, { @@ -513,28 +516,28 @@ "outputs": [], "source": [ "model_params = {\n", - " \"param_tag\": \"default\",\n", - " \"gnn_type\": \"RGCNConv\",\n", - " \"depth\": 5,\n", - " \"hidden_dimension\": 300,\n", - " \"dense_layers\": 2,\n", - " \"embedding_aggregation\": \"concat\",\n", - " \"embedding_dimension\": 300,\n", - " \"input_dropout\": 0.2,\n", - " \"latent_dropout\": 0.1,\n", - " \"node_feature_layout\": node_encoder.feature_numbers,\n", - " \"edge_feature_layout\": bond_encoder.feature_numbers,\n", - " \"static_feature_dimension\": geo_data[0][\"static_edge_features\"].shape[1],\n", - " \"static_rt_feature_dimension\": geo_data[0][\"static_rt_features\"].shape[1],\n", - " \"output_dimension\": len(DEFAULT_MODES) * 2, # per edge\n", + " 'param_tag': 'default',\n", + " 'gnn_type': 'RGCNConv',\n", + " 
'depth': 5,\n", + " 'hidden_dimension': 300,\n", + " 'dense_layers': 2,\n", + " 'embedding_aggregation': 'concat',\n", + " 'embedding_dimension': 300,\n", + " 'input_dropout': 0.2,\n", + " 'latent_dropout': 0.1,\n", + " 'node_feature_layout': node_encoder.feature_numbers,\n", + " 'edge_feature_layout': bond_encoder.feature_numbers,\n", + " 'static_feature_dimension': geo_data[0]['static_edge_features'].shape[1],\n", + " 'static_rt_feature_dimension': geo_data[0]['static_rt_features'].shape[1],\n", + " 'output_dimension': len(DEFAULT_MODES) * 2, # per edge\n", "}\n", "training_params = {\n", - " \"epochs\": 200 if not test_run else 10,\n", - " \"batch_size\": 256, # 128,\n", + " 'epochs': 200 if not test_run else 10,\n", + " 'batch_size': 256, # 128,\n", " #'train_val_split': 0.90,\n", - " \"learning_rate\": 0.0004, # 0.001,\n", - " \"with_RT\": False,\n", - " \"with_CCS\": False,\n", + " 'learning_rate': 0.0004, # 0.001,\n", + " 'with_RT': False,\n", + " 'with_CCS': False,\n", "}" ] }, @@ -544,17 +547,17 @@ "metadata": {}, "outputs": [], "source": [ - "fixed_params = {\"gnn_type\": \"RGCNConv\"} # Mainly used for clarity\n", + "fixed_params = {'gnn_type': 'RGCNConv'} # Mainly used for clarity\n", "grid_params = [\n", - " {\"depth\": 0},\n", - " {\"depth\": 1},\n", - " {\"depth\": 2},\n", - " {\"depth\": 3},\n", - " {\"depth\": 4},\n", - " {\"depth\": 5},\n", - " {\"depth\": 6},\n", - " {\"depth\": 7},\n", - " {\"depth\": 8},\n", + " {'depth': 0},\n", + " {'depth': 1},\n", + " {'depth': 2},\n", + " {'depth': 3},\n", + " {'depth': 4},\n", + " {'depth': 5},\n", + " {'depth': 6},\n", + " {'depth': 7},\n", + " {'depth': 8},\n", "] # ,{'depth': 0}, {'depth': 1}, {'depth': 7}, {'depth': 8}]\n", "for p in grid_params:\n", " p.update(fixed_params)\n", @@ -577,35 +580,36 @@ "outputs": [], "source": [ "from fiora.GNN.GNNModules import GNNCompiler\n", + "\n", "from fiora.GNN.Losses import WeightedMSELoss, WeightedMSEMetric\n", "from fiora.MS.SimulationFramework import SimulationFramework\n", "\n", "fiora = SimulationFramework(\n", " None,\n", " dev=dev,\n", - " with_RT=training_params[\"with_RT\"],\n", - " with_CCS=training_params[\"with_CCS\"],\n", + " with_RT=training_params['with_RT'],\n", + " with_CCS=training_params['with_CCS'],\n", ")\n", - "np.seterr(invalid=\"ignore\")\n", + "np.seterr(invalid='ignore')\n", "val_interval = 1\n", - "tag = \"grid\"\n", + "tag = 'grid'\n", "val_interval = 1\n", - "metric_dict = {\"mse\": WeightedMSEMetric}\n", + "metric_dict = {'mse': WeightedMSEMetric}\n", "loss_fn = WeightedMSELoss()\n", "\n", "\n", "def train_new_model():\n", " model = GNNCompiler(model_params).to(dev)\n", "\n", - " y_label = \"compiled_probsALL\"\n", + " y_label = 'compiled_probsALL'\n", " train_keys, val_keys = (\n", - " df_train[df_train[\"dataset\"] == \"training\"][\"group_id\"].unique(),\n", - " df_train[df_train[\"dataset\"] == \"validation\"][\"group_id\"].unique(),\n", + " df_train[df_train['dataset'] == 'training']['group_id'].unique(),\n", + " df_train[df_train['dataset'] == 'validation']['group_id'].unique(),\n", " )\n", " trainer = Trainer(\n", " geo_data,\n", " y_tag=y_label,\n", - " problem_type=\"regression\",\n", + " problem_type='regression',\n", " train_keys=train_keys,\n", " val_keys=val_keys,\n", " metric_dict=metric_dict,\n", @@ -614,14 +618,14 @@ " device=dev,\n", " )\n", " optimizer = torch.optim.Adam(\n", - " model.parameters(), lr=training_params[\"learning_rate\"]\n", + " model.parameters(), lr=training_params['learning_rate']\n", " )\n", " # scheduler = 
torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)\n", " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( # TODO doesn't work with onlyTraining\n", " optimizer,\n", " patience=8, # 10 default\n", " # factor = self.hparams['factor'],\n", - " mode=\"min\",\n", + " mode='min',\n", " verbose=True,\n", " )\n", "\n", @@ -633,11 +637,11 @@ " optimizer,\n", " loss_fn,\n", " scheduler=scheduler,\n", - " batch_size=training_params[\"batch_size\"],\n", - " epochs=training_params[\"epochs\"],\n", + " batch_size=training_params['batch_size'],\n", + " epochs=training_params['epochs'],\n", " val_every_n_epochs=val_interval,\n", - " with_CCS=training_params[\"with_CCS\"],\n", - " with_RT=training_params[\"with_RT\"],\n", + " with_CCS=training_params['with_CCS'],\n", + " with_RT=training_params['with_RT'],\n", " masked_validation=False,\n", " tag=tag,\n", " ) # , mask_name=\"compiled_validation_maskALL\")\n", @@ -652,7 +656,7 @@ "def test_model(model, DF):\n", " dft = simulate_all(model, DF)\n", "\n", - " return dft[\"spectral_sqrt_cosine\"].values" + " return dft['spectral_sqrt_cosine'].values" ] }, { @@ -668,97 +672,96 @@ "metadata": {}, "outputs": [], "source": [ - "from fiora.MOL.collision_energy import NCE_to_eV\n", + "from fiora.MS.ms_utility import merge_annotated_spectrum\n", "from fiora.MS.spectral_scores import (\n", + " reweighted_dot,\n", " spectral_cosine,\n", " spectral_reflection_cosine,\n", - " reweighted_dot,\n", ")\n", - "from fiora.MS.ms_utility import merge_annotated_spectrum\n", "\n", "\n", "def test_cas16(model, df_cas=df_cas):\n", "\n", - " df_cas[\"NCE\"] = 20.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " df_cas['NCE'] = 20.0 # actually stepped NCE 20/35/50\n", + " df_cas['CE'] = df_cas[['NCE', 'PRECURSOR_MZ']].apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['PRECURSOR_MZ']), axis=1\n", " )\n", - " df_cas[\"step1_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(\n", + " df_cas['step1_CE'] = df_cas['CE']\n", + " df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", " )\n", " df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", " )\n", - " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_20\")\n", + " df_cas = fiora.simulate_all(df_cas, model, suffix='_20')\n", "\n", - " df_cas[\"NCE\"] = 35.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " df_cas['NCE'] = 35.0 # actually stepped NCE 20/35/50\n", + " df_cas['CE'] = df_cas[['NCE', 'PRECURSOR_MZ']].apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['PRECURSOR_MZ']), axis=1\n", " )\n", - " df_cas[\"step2_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(\n", + " df_cas['step2_CE'] = df_cas['CE']\n", + " df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", " )\n", " df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", " )\n", - " df_cas = fiora.simulate_all(df_cas, model, 
suffix=\"_35\")\n", + " df_cas = fiora.simulate_all(df_cas, model, suffix='_35')\n", "\n", - " df_cas[\"NCE\"] = 50.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " df_cas['NCE'] = 50.0 # actually stepped NCE 20/35/50\n", + " df_cas['CE'] = df_cas[['NCE', 'PRECURSOR_MZ']].apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['PRECURSOR_MZ']), axis=1\n", " )\n", - " df_cas[\"step3_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(\n", + " df_cas['step3_CE'] = df_cas['CE']\n", + " df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", " )\n", " df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", " )\n", - " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_50\")\n", + " df_cas = fiora.simulate_all(df_cas, model, suffix='_50')\n", "\n", - " df_cas[\"avg_CE\"] = (\n", - " df_cas[\"step1_CE\"] + df_cas[\"step2_CE\"] + df_cas[\"step3_CE\"]\n", + " df_cas['avg_CE'] = (\n", + " df_cas['step1_CE'] + df_cas['step2_CE'] + df_cas['step3_CE']\n", " ) / 3\n", "\n", - " df_cas[\"merged_peaks\"] = df_cas.apply(\n", + " df_cas['merged_peaks'] = df_cas.apply(\n", " lambda x: merge_annotated_spectrum(\n", - " merge_annotated_spectrum(x[\"sim_peaks_20\"], x[\"sim_peaks_35\"]),\n", - " x[\"sim_peaks_50\"],\n", + " merge_annotated_spectrum(x['sim_peaks_20'], x['sim_peaks_35']),\n", + " x['sim_peaks_50'],\n", " ),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_cosine\"] = df_cas.apply(\n", - " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " df_cas['merged_cosine'] = df_cas.apply(\n", + " lambda x: spectral_cosine(x['peaks'], x['merged_peaks']), axis=1\n", " )\n", - " df_cas[\"merged_sqrt_cosine\"] = df_cas.apply(\n", - " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt),\n", + " df_cas['merged_sqrt_cosine'] = df_cas.apply(\n", + " lambda x: spectral_cosine(x['peaks'], x['merged_peaks'], transform=np.sqrt),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_refl_cosine\"] = df_cas.apply(\n", + " df_cas['merged_refl_cosine'] = df_cas.apply(\n", " lambda x: spectral_reflection_cosine(\n", - " x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt\n", + " x['peaks'], x['merged_peaks'], transform=np.sqrt\n", " ),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_steins\"] = df_cas.apply(\n", - " lambda x: reweighted_dot(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " df_cas['merged_steins'] = df_cas.apply(\n", + " lambda x: reweighted_dot(x['peaks'], x['merged_peaks']), axis=1\n", " )\n", - " df_cas[\"spectral_sqrt_cosine\"] = df_cas[\n", - " \"merged_sqrt_cosine\"\n", + " df_cas['spectral_sqrt_cosine'] = df_cas[\n", + " 'merged_sqrt_cosine'\n", " ] # just remember it is merged\n", "\n", - " df_cas[\"coverage\"] = df_cas[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", - " df_cas[\"RT_pred\"] = df_cas[\"RT_pred_35\"]\n", - " df_cas[\"RT_dif\"] = df_cas[\"RT_dif_35\"]\n", - " df_cas[\"CCS_pred\"] = df_cas[\"CCS_pred_35\"]\n", - " df_cas[\"library\"] = \"CASMI-16\"\n", + " df_cas['coverage'] = df_cas['Metabolite'].apply(lambda x: x.match_stats['coverage'])\n", + " df_cas['RT_pred'] = df_cas['RT_pred_35']\n", + " df_cas['RT_dif'] = df_cas['RT_dif_35']\n", + " df_cas['CCS_pred'] = 
df_cas['CCS_pred_35']\n", + " df_cas['library'] = 'CASMI-16'\n", "\n", - " return df_cas[\"merged_sqrt_cosine\"].values" + " return df_cas['merged_sqrt_cosine'].values" ] }, { @@ -2855,72 +2858,72 @@ "results_cp = []\n", "\n", "for params in grid_params:\n", - " print(f\"Testing {params}\")\n", + " print(f'Testing {params}')\n", " model_params.update(params)\n", " current_model, checkpoint = train_new_model()\n", " val_results = test_model(\n", - " current_model, df_train[df_train[\"dataset\"] == \"validation\"]\n", + " current_model, df_train[df_train['dataset'] == 'validation']\n", " )\n", " test_results = test_model(current_model, df_test)\n", " casmi16_results = test_cas16(current_model)\n", - " casmi16_p = test_cas16(current_model, df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"])\n", - " casmi16_n = test_cas16(current_model, df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"])\n", + " casmi16_p = test_cas16(current_model, df_cas[df_cas['Precursor_type'] == '[M+H]+'])\n", + " casmi16_n = test_cas16(current_model, df_cas[df_cas['Precursor_type'] == '[M-H]-'])\n", " casmi22_results = test_model(current_model, df_cas22)\n", " casmi22_p = test_model(\n", - " current_model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"]\n", + " current_model, df_cas22[df_cas22['Precursor_type'] == '[M+H]+']\n", " )\n", " casmi22_n = test_model(\n", - " current_model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"]\n", + " current_model, df_cas22[df_cas22['Precursor_type'] == '[M-H]-']\n", " )\n", "\n", " results.append(\n", " {\n", " **params,\n", - " \"model\": copy.deepcopy(current_model),\n", - " \"cp\": checkpoint,\n", - " \"validation\": val_results,\n", - " \"test\": test_results,\n", - " \"casmi16\": casmi16_results,\n", - " \"casmi22\": casmi22_results,\n", - " \"casmi16+\": casmi16_p,\n", - " \"casmi16-\": casmi16_n,\n", - " \"casmi22+\": casmi22_p,\n", - " \"casmi22-\": casmi22_n,\n", + " 'model': copy.deepcopy(current_model),\n", + " 'cp': checkpoint,\n", + " 'validation': val_results,\n", + " 'test': test_results,\n", + " 'casmi16': casmi16_results,\n", + " 'casmi22': casmi22_results,\n", + " 'casmi16+': casmi16_p,\n", + " 'casmi16-': casmi16_n,\n", + " 'casmi22+': casmi22_p,\n", + " 'casmi22-': casmi22_n,\n", " }\n", " )\n", "\n", - " current_model = current_model.load(checkpoint[\"file\"])\n", + " current_model = current_model.load(checkpoint['file'])\n", " val_results_cp = test_model(\n", - " current_model, df_train[df_train[\"dataset\"] == \"validation\"]\n", + " current_model, df_train[df_train['dataset'] == 'validation']\n", " )\n", " test_results_cp = test_model(current_model, df_test)\n", " casmi16_results_cp = test_cas16(current_model)\n", " casmi16_p_cp = test_cas16(\n", - " current_model, df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"]\n", + " current_model, df_cas[df_cas['Precursor_type'] == '[M+H]+']\n", " )\n", " casmi16_n_cp = test_cas16(\n", - " current_model, df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"]\n", + " current_model, df_cas[df_cas['Precursor_type'] == '[M-H]-']\n", " )\n", " casmi22_results_cp = test_model(current_model, df_cas22)\n", " casmi22_p_cp = test_model(\n", - " current_model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"]\n", + " current_model, df_cas22[df_cas22['Precursor_type'] == '[M+H]+']\n", " )\n", " casmi22_n_cp = test_model(\n", - " current_model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"]\n", + " current_model, df_cas22[df_cas22['Precursor_type'] == '[M-H]-']\n", " )\n", " results_cp.append(\n", " {\n", " **params,\n", - 
" \"model\": copy.deepcopy(current_model),\n", - " \"cp\": checkpoint,\n", - " \"validation\": val_results_cp,\n", - " \"test\": test_results_cp,\n", - " \"casmi16\": casmi16_results_cp,\n", - " \"casmi22\": casmi22_results_cp,\n", - " \"casmi16+\": casmi16_p_cp,\n", - " \"casmi16-\": casmi16_n_cp,\n", - " \"casmi22+\": casmi22_p_cp,\n", - " \"casmi22-\": casmi22_n_cp,\n", + " 'model': copy.deepcopy(current_model),\n", + " 'cp': checkpoint,\n", + " 'validation': val_results_cp,\n", + " 'test': test_results_cp,\n", + " 'casmi16': casmi16_results_cp,\n", + " 'casmi22': casmi22_results_cp,\n", + " 'casmi16+': casmi16_p_cp,\n", + " 'casmi16-': casmi16_n_cp,\n", + " 'casmi22+': casmi22_p_cp,\n", + " 'casmi22-': casmi22_n_cp,\n", " }\n", " )" ] @@ -2954,11 +2957,11 @@ "LOG = pd.DataFrame(results)\n", "eval_columns = LOG.columns[4:]\n", "\n", - "home_path = f\"{home}/data/metabolites/benchmarking/\"\n", - "NAME = model_params[\"gnn_type\"] + \"_Jan24_test.csv\"\n", + "home_path = f'{home}/data/metabolites/benchmarking/'\n", + "NAME = model_params['gnn_type'] + '_Jan24_test.csv'\n", "for col in eval_columns:\n", " LOG[col] = LOG[col].apply(lambda x: str(list(x)))\n", - "LOG.to_csv(home_path + NAME, index=False, sep=\"\\t\")" + "LOG.to_csv(home_path + NAME, index=False, sep='\\t')" ] }, { @@ -3155,7 +3158,7 @@ } ], "source": [ - "results[2][\"cp\"]" + "results[2]['cp']" ] }, { @@ -3342,11 +3345,11 @@ "LOG = pd.DataFrame(results_cp)\n", "eval_columns = LOG.columns[4:]\n", "\n", - "home_path = f\"{home}/data/metabolites/benchmarking/\"\n", - "NAME = model_params[\"gnn_type\"] + \"_cp_Jan24_test.csv\"\n", + "home_path = f'{home}/data/metabolites/benchmarking/'\n", + "NAME = model_params['gnn_type'] + '_cp_Jan24_test.csv'\n", "for col in eval_columns:\n", " LOG[col] = LOG[col].apply(lambda x: str(list(x)))\n", - "LOG.to_csv(home_path + NAME, index=False, sep=\"\\t\")" + "LOG.to_csv(home_path + NAME, index=False, sep='\\t')" ] }, { @@ -3355,9 +3358,9 @@ "metadata": {}, "outputs": [], "source": [ - "LOGIC = pd.read_csv(home_path + NAME, sep=\"\\t\")\n", + "LOGIC = pd.read_csv(home_path + NAME, sep='\\t')\n", "for col in eval_columns:\n", - " LOGIC[col] = LOGIC[col].apply(lambda x: ast.literal_eval(x.replace(\"nan\", \"None\")))\n", + " LOGIC[col] = LOGIC[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", "# LOGIC[eval_columns].apply(lambda x: x.apply(np.median))" ] }, diff --git a/notebooks/grid_stats.ipynb b/notebooks/grid_stats.ipynb index 5fc3b48..b78c51b 100644 --- a/notebooks/grid_stats.ipynb +++ b/notebooks/grid_stats.ipynb @@ -24,6 +24,7 @@ ], "source": [ "import sys\n", + "\n", "import torch\n", "\n", "seed = 42\n", @@ -32,36 +33,35 @@ "torch.set_printoptions(precision=2, sci_mode=False)\n", "\n", "\n", - "import pandas as pd\n", - "import numpy as np\n", "import ast\n", "import copy\n", "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", "# Load Modules\n", - "sys.path.append(\"..\")\n", + "sys.path.append('..')\n", "from os.path import expanduser\n", "\n", - "home = expanduser(\"~\")\n", - "from fiora.MOL.constants import DEFAULT_PPM, PPM, DEFAULT_MODES\n", + "home = expanduser('~')\n", "from fiora.IO.LibraryLoader import LibraryLoader\n", + "from fiora.MOL.constants import DEFAULT_MODES, DEFAULT_PPM, PPM\n", "from fiora.MOL.FragmentationTree import FragmentationTree\n", - "from fiora.visualization.define_colors import set_light_theme\n", "from fiora.visualization.define_colors import *\n", + "from fiora.visualization.define_colors import set_light_theme\n", 
"\n", "set_light_theme()\n", - "import fiora.visualization.spectrum_visualizer as sv\n", - "\n", - "\n", "import matplotlib.pyplot as plt\n", - "import seaborn as sns\n", - "\n", - "from sklearn.metrics import r2_score\n", "import scipy\n", + "import seaborn as sns\n", "from rdkit import RDLogger\n", + "from sklearn.metrics import r2_score\n", + "\n", + "import fiora.visualization.spectrum_visualizer as sv\n", "\n", - "RDLogger.DisableLog(\"rdApp.*\")\n", + "RDLogger.DisableLog('rdApp.*')\n", "\n", - "print(f\"Working with Python {sys.version}\")" + "print(f'Working with Python {sys.version}')" ] }, { @@ -79,11 +79,11 @@ "outputs": [], "source": [ "def read_log(file):\n", - " LOG = pd.read_csv(file, sep=\"\\t\")\n", + " LOG = pd.read_csv(file, sep='\\t')\n", " eval_columns = LOG.columns[3:]\n", "\n", " for col in eval_columns:\n", - " LOG[col] = LOG[col].apply(lambda x: ast.literal_eval(x.replace(\"nan\", \"None\")))\n", + " LOG[col] = LOG[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", " return LOG" ] }, @@ -93,9 +93,9 @@ "metadata": {}, "outputs": [], "source": [ - "path = f\"{home}/data/metabolites/benchmarking/\"\n", - "suffix = \"_cp_Jan24.csv\"\n", - "NAMES = [\"GraphConv\", \"RGCNConv\", \"GAT\", \"TransformerConv\"] # , \"CGConv_depth.csv\",\n", + "path = f'{home}/data/metabolites/benchmarking/'\n", + "suffix = '_cp_Jan24.csv'\n", + "NAMES = ['GraphConv', 'RGCNConv', 'GAT', 'TransformerConv'] # , \"CGConv_depth.csv\",\n", "\n", "log = []\n", "for name in NAMES:\n", @@ -1088,11 +1088,11 @@ "outputs": [], "source": [ "gnn_type_labels = {\n", - " \"GraphConv\": \"GCN\",\n", - " \"CGConv\": \"CGC\",\n", - " \"GAT\": \"GAT\",\n", - " \"RGCNConv\": \"RGCN\",\n", - " \"TransformerConv\": \"Transformer\",\n", + " 'GraphConv': 'GCN',\n", + " 'CGConv': 'CGC',\n", + " 'GAT': 'GAT',\n", + " 'RGCNConv': 'RGCN',\n", + " 'TransformerConv': 'Transformer',\n", "}" ] }, @@ -1113,7 +1113,7 @@ } ], "source": [ - "sns.palplot(sns.color_palette(\"colorblind\"))\n", + "sns.palplot(sns.color_palette('colorblind'))\n", "plt.show()" ] }, @@ -1138,51 +1138,51 @@ "plt.subplots_adjust(top=0.94, bottom=0.12, right=0.97, left=0.08)\n", "\n", "color_blind4 = [\n", - " sns.color_palette(\"colorblind\")[4],\n", - " sns.color_palette(\"colorblind\")[0],\n", - " sns.color_palette(\"colorblind\")[2],\n", - " sns.color_palette(\"colorblind\")[3],\n", + " sns.color_palette('colorblind')[4],\n", + " sns.color_palette('colorblind')[0],\n", + " sns.color_palette('colorblind')[2],\n", + " sns.color_palette('colorblind')[3],\n", "]\n", - "L = log.explode(\"validation\")\n", - "L[\"gnn_type\"] = L[\"gnn_type\"].map(gnn_type_labels)\n", + "L = log.explode('validation')\n", + "L['gnn_type'] = L['gnn_type'].map(gnn_type_labels)\n", "sns.pointplot(\n", " data=L,\n", - " x=\"depth\",\n", - " y=\"validation\",\n", - " estimator=\"median\",\n", + " x='depth',\n", + " y='validation',\n", + " estimator='median',\n", " capsize=0.0,\n", - " markers=\"o\",\n", + " markers='o',\n", " palette=color_blind4,\n", " markersize=5,\n", - " errorbar=(\"ci\", 95),\n", - " linestyles=\"--\",\n", - " hue=\"gnn_type\",\n", + " errorbar=('ci', 95),\n", + " linestyles='--',\n", + " hue='gnn_type',\n", " dodge=0.4,\n", ") # ci=('ci', 0.95),, palette=tri_palette, bins=[x/10.0 for x in list(range(0,10,1))]), multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", "# plt.ylim([0.4, 0.8]) # markers=[\"o\", \"X\", \"^\",\"*\"]\n", - "plt.ylabel(\"Cosine similarity\")\n", - 
"plt.xlabel(\"Graph network depth\")\n", + "plt.ylabel('Cosine similarity')\n", + "plt.xlabel('Graph network depth')\n", "plt.yticks(np.arange(0.5, 0.91, 0.05))\n", "# plt.autoscale(enable=True, axis='y')\n", "plt.ylim(0.575, 0.87)\n", "# plt.ylim(0.25, 0.95)\n", - "plt.legend(title=\"\")\n", + "plt.legend(title='')\n", "\n", "# for line in ax.lines:\n", "# marker = line.get_marker()\n", "# line.set_markeredgecolor('black')\n", "\n", - "plt.rc(\"axes\", labelsize=14)\n", - "plt.rc(\"legend\", fontsize=14)\n", - "ax.tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", + "plt.rc('axes', labelsize=14)\n", + "plt.rc('legend', fontsize=14)\n", + "ax.tick_params(axis='both', which='major', labelsize=13)\n", "ax.text(\n", " 0.02,\n", " 0.02,\n", - " \"n=7212 for all data points\",\n", + " 'n=7212 for all data points',\n", " transform=ax.transAxes,\n", " fontsize=13,\n", - " va=\"bottom\",\n", - " ha=\"left\",\n", + " va='bottom',\n", + " ha='left',\n", ")\n", "\n", "# fig.savefig(f\"{home}/images/paper/grid_params3.svg\", format=\"svg\", dpi=600)\n", @@ -1212,50 +1212,50 @@ "plt.subplots_adjust(top=0.94, bottom=0.12, right=0.97, left=0.08)\n", "\n", "color_blind4 = [\n", - " sns.color_palette(\"colorblind\")[4],\n", - " sns.color_palette(\"colorblind\")[0],\n", - " sns.color_palette(\"colorblind\")[2],\n", - " sns.color_palette(\"colorblind\")[3],\n", + " sns.color_palette('colorblind')[4],\n", + " sns.color_palette('colorblind')[0],\n", + " sns.color_palette('colorblind')[2],\n", + " sns.color_palette('colorblind')[3],\n", "]\n", - "L = log.explode(\"validation\")\n", - "L[\"gnn_type\"] = L[\"gnn_type\"].map(gnn_type_labels)\n", + "L = log.explode('validation')\n", + "L['gnn_type'] = L['gnn_type'].map(gnn_type_labels)\n", "sns.pointplot(\n", " data=L,\n", - " x=\"depth\",\n", - " y=\"validation\",\n", - " estimator=\"median\",\n", + " x='depth',\n", + " y='validation',\n", + " estimator='median',\n", " capsize=0.0,\n", - " markers=\"o\",\n", + " markers='o',\n", " palette=color_blind4,\n", " markersize=5,\n", - " errorbar=(\"pi\", 50),\n", - " linestyles=\"--\",\n", - " hue=\"gnn_type\",\n", + " errorbar=('pi', 50),\n", + " linestyles='--',\n", + " hue='gnn_type',\n", " dodge=0.4,\n", ") # ci=('ci', 0.95),, palette=tri_palette, bins=[x/10.0 for x in list(range(0,10,1))]), multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", "# plt.ylim([0.4, 0.8]) # markers=[\"o\", \"X\", \"^\",\"*\"]\n", - "plt.ylabel(\"Cosine similarity\")\n", - "plt.xlabel(\"Graph network depth\")\n", + "plt.ylabel('Cosine similarity')\n", + "plt.xlabel('Graph network depth')\n", "plt.yticks(np.arange(0.4, 0.91, 0.05))\n", "# plt.autoscale(enable=True, axis='y')\n", "plt.ylim(0.4, 0.95)\n", "# plt.ylim(0.25, 0.95)\n", - "plt.legend(title=\"\")\n", + "plt.legend(title='')\n", "\n", "# for line in ax.lines:\n", "# marker = line.get_marker()\n", "# line.set_markeredgecolor('black')\n", - "plt.rc(\"axes\", labelsize=14)\n", - "plt.rc(\"legend\", fontsize=14)\n", - "ax.tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", + "plt.rc('axes', labelsize=14)\n", + "plt.rc('legend', fontsize=14)\n", + "ax.tick_params(axis='both', which='major', labelsize=13)\n", "ax.text(\n", " 0.02,\n", " 0.02,\n", - " \"n=7212 for all data points\",\n", + " 'n=7212 for all data points',\n", " transform=ax.transAxes,\n", " fontsize=13,\n", - " va=\"bottom\",\n", - " ha=\"left\",\n", + " va='bottom',\n", + " ha='left',\n", ")\n", "\n", "# 
fig.savefig(f\"{home}/images/paper/grid_params3_iqr.svg\", format=\"svg\", dpi=600)\n", @@ -1276,15 +1276,15 @@ "\n", "\n", "def print_stats(scores):\n", - " print(f\"Num of spectra: {len(scores)}\")\n", - " print(f\"Median:\\t{np.median(scores):.3f}\")\n", - " print(f\"Var:\\t{np.var(scores):.3f} (Standard deviation: {np.std(scores):.3f})\")\n", + " print(f'Num of spectra: {len(scores)}')\n", + " print(f'Median:\\t{np.median(scores):.3f}')\n", + " print(f'Var:\\t{np.var(scores):.3f} (Standard deviation: {np.std(scores):.3f})')\n", " conf_in_t = st.t.interval(\n", " confidence=0.95, df=len(scores) - 1, loc=np.median(scores), scale=st.sem(scores)\n", " )\n", " conf_in_boot = bootstrap((scores,), np.median, confidence_level=0.95)\n", - " print(f\"95%CI: {conf_in_t} (from t distribution)\")\n", - " print(f\"95%CI: {conf_in_boot.confidence_interval} (from bootstrapping)\")" + " print(f'95%CI: {conf_in_t} (from t distribution)')\n", + " print(f'95%CI: {conf_in_boot.confidence_interval} (from bootstrapping)')" ] }, { @@ -1305,7 +1305,7 @@ } ], "source": [ - "val_scores = L[(L[\"gnn_type\"] == \"RGCN\") & (L[\"depth\"] == 6)][\"validation\"]\n", + "val_scores = L[(L['gnn_type'] == 'RGCN') & (L['depth'] == 6)]['validation']\n", "print_stats(list(val_scores))" ] }, @@ -1480,7 +1480,7 @@ ], "source": [ "eval_columns = log.columns[4:]\n", - "log[log[\"gnn_type\"] == \"RGCNConv\"][eval_columns].apply(\n", + "log[log['gnn_type'] == 'RGCNConv'][eval_columns].apply(\n", " lambda x: x.apply(np.mean), axis=1\n", ")" ] @@ -1510,13 +1510,13 @@ } ], "source": [ - "log4 = log[(log[\"gnn_type\"] == \"GraphConv\") & (log[\"depth\"] == 4)]\n", - "L4 = log4.explode(\"validation\")\n", - "sns.histplot(data=L4, x=\"validation\")\n", + "log4 = log[(log['gnn_type'] == 'GraphConv') & (log['depth'] == 4)]\n", + "L4 = log4.explode('validation')\n", + "sns.histplot(data=L4, x='validation')\n", "plt.show()\n", "\n", - "print(L4[\"validation\"].median())\n", - "print(np.std(L4[\"validation\"]))" + "print(L4['validation'].median())\n", + "print(np.std(L4['validation']))" ] }, { @@ -1546,7 +1546,7 @@ } ], "source": [ - "sns.boxplot(data=L4, y=\"validation\")" + "sns.boxplot(data=L4, y='validation')" ] }, { @@ -2531,7 +2531,7 @@ "metadata": {}, "outputs": [], "source": [ - "L_export = L[[\"depth\", \"gnn_type\", \"validation\"]]" + "L_export = L[['depth', 'gnn_type', 'validation']]" ] }, { @@ -2542,9 +2542,9 @@ "source": [ "L_export = L_export.rename(\n", " columns={\n", - " \"depth\": \"Depth\",\n", - " \"gnn_type\": \"GNN Architecture\",\n", - " \"validation\": \"Cosine Similarity\",\n", + " 'depth': 'Depth',\n", + " 'gnn_type': 'GNN Architecture',\n", + " 'validation': 'Cosine Similarity',\n", " }\n", ")" ] @@ -2555,8 +2555,8 @@ "metadata": {}, "outputs": [], "source": [ - "L_export[\"Validation ID\"] = (\n", - " L_export.groupby([\"Depth\", \"GNN Architecture\"]).cumcount() + 1\n", + "L_export['Validation ID'] = (\n", + " L_export.groupby(['Depth', 'GNN Architecture']).cumcount() + 1\n", ")" ] }, @@ -2575,8 +2575,8 @@ "metadata": {}, "outputs": [], "source": [ - "L_export[[\"Depth\", \"GNN Architecture\", \"Validation ID\", \"Cosine Similarity\"]].to_excel(\n", - " f\"{home}/images/paper/SourceData_Figure2.xlsx\"\n", + "L_export[['Depth', 'GNN Architecture', 'Validation ID', 'Cosine Similarity']].to_excel(\n", + " f'{home}/images/paper/SourceData_Figure2.xlsx'\n", ")" ] } diff --git a/notebooks/info_graphs.ipynb b/notebooks/info_graphs.ipynb index 9e24c71..6eee47b 100644 --- a/notebooks/info_graphs.ipynb +++ 
b/notebooks/info_graphs.ipynb @@ -24,6 +24,7 @@ ], "source": [ "import sys\n", + "\n", "import torch\n", "\n", "seed = 42\n", @@ -32,28 +33,29 @@ "torch.set_printoptions(precision=2, sci_mode=False)\n", "\n", "\n", - "import pandas as pd\n", - "import numpy as np\n", "import ast\n", "import copy\n", "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", "# Load Modules\n", - "sys.path.append(\"..\")\n", + "sys.path.append('..')\n", "from os.path import expanduser\n", "\n", - "home = expanduser(\"~\")\n", - "from fiora.MOL.constants import DEFAULT_PPM, PPM, DEFAULT_MODES\n", - "from fiora.IO.LibraryLoader import LibraryLoader\n", - "from fiora.MOL.FragmentationTree import FragmentationTree\n", - "import fiora.visualization.spectrum_visualizer as sv\n", - "\n", - "from sklearn.metrics import r2_score\n", + "home = expanduser('~')\n", "import scipy\n", "from rdkit import RDLogger\n", + "from sklearn.metrics import r2_score\n", + "\n", + "import fiora.visualization.spectrum_visualizer as sv\n", + "from fiora.IO.LibraryLoader import LibraryLoader\n", + "from fiora.MOL.constants import DEFAULT_MODES, DEFAULT_PPM, PPM\n", + "from fiora.MOL.FragmentationTree import FragmentationTree\n", "\n", - "RDLogger.DisableLog(\"rdApp.*\")\n", + "RDLogger.DisableLog('rdApp.*')\n", "\n", - "print(f\"Working with Python {sys.version}\")" + "print(f'Working with Python {sys.version}')" ] }, { @@ -72,13 +74,13 @@ "source": [ "from typing import Literal\n", "\n", - "lib: Literal[\"NIST\", \"MSDIAL\", \"NIST/MSDIAL\", \"MSnLib\"] = \"MSnLib\" # \"MSnLib\"\n", - "print(f\"Preparing {lib} library\")\n", + "lib: Literal['NIST', 'MSDIAL', 'NIST/MSDIAL', 'MSnLib'] = 'MSnLib' # \"MSnLib\"\n", + "print(f'Preparing {lib} library')\n", "\n", "debug_mode = False # Default: False\n", "if debug_mode:\n", " print(\n", - " \"+++ This is a test run (debug mode) with a small subset of data points. Results are not representative. +++\"\n", + " '+++ This is a test run (debug mode) with a small subset of data points. Results are not representative. 
+++'\n", " )" ] }, @@ -90,14 +92,14 @@ "source": [ "# key map to read metadata from pandas DataFrame\n", "metadata_key_map = {\n", - " \"name\": \"Name\",\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"ionization\": \"Ionization\",\n", - " \"precursor_mz\": \"PrecursorMZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", - " \"ccs\": \"CCS\",\n", + " 'name': 'Name',\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'ionization': 'Ionization',\n", + " 'precursor_mz': 'PrecursorMZ',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'RETENTIONTIME',\n", + " 'ccs': 'CCS',\n", "}\n", "\n", "\n", @@ -107,14 +109,14 @@ "\n", "\n", "def load_training_data():\n", - " if \"NIST\" in lib or \"MSDIAL\" in lib:\n", - " data_path: str = f\"{home}/data/metabolites/preprocessed/datasplits_Jan24.csv\"\n", - " elif lib == \"MSnLib\":\n", + " if 'NIST' in lib or 'MSDIAL' in lib:\n", + " data_path: str = f'{home}/data/metabolites/preprocessed/datasplits_Jan24.csv'\n", + " elif lib == 'MSnLib':\n", " data_path: str = (\n", - " f\"{home}/data/metabolites/preprocessed/datasplits_msnlib_April25_v1.csv\"\n", + " f'{home}/data/metabolites/preprocessed/datasplits_msnlib_April25_v1.csv'\n", " )\n", " else:\n", - " raise NameError(f\"Unknown library selected {lib=}.\")\n", + " raise NameError(f'Unknown library selected {lib=}.')\n", " L = LibraryLoader()\n", " df = L.load_from_csv(data_path)\n", " return df\n", @@ -123,12 +125,12 @@ "df = load_training_data()\n", "\n", "# Restore dictionary values\n", - "dict_columns = [\"peaks\", \"summary\"]\n", + "dict_columns = ['peaks', 'summary']\n", "for col in dict_columns:\n", - " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace(\"nan\", \"None\")))\n", + " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", " # df[col] = df[col].apply(ast.literal_eval)\n", "\n", - "df[\"group_id\"] = df[\"group_id\"].astype(int)" + "df['group_id'] = df['group_id'].astype(int)" ] }, { @@ -137,11 +139,10 @@ "metadata": {}, "outputs": [], "source": [ - "from fiora.MOL.Metabolite import Metabolite\n", "from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", "from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder\n", "from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder\n", - "\n", + "from fiora.MOL.Metabolite import Metabolite\n", "\n", "CE_upper_limit = 100.0\n", "weight_upper_limit = 1000.0\n", @@ -152,51 +153,51 @@ " # df = df.iloc[5000:20000,:]\n", "\n", "overwrite_setup_features = None\n", - "if lib == \"MSnLib\":\n", + "if lib == 'MSnLib':\n", " overwrite_setup_features = {\n", - " \"instrument\": [\"HCD\"],\n", - " \"precursor_mode\": [\"[M+H]+\", \"[M-H]-\", \"[M]+\", \"[M]-\"],\n", + " 'instrument': ['HCD'],\n", + " 'precursor_mode': ['[M+H]+', '[M-H]-', '[M]+', '[M]-'],\n", " }\n", "\n", "\n", - "df[\"Metabolite\"] = df[\"SMILES\"].apply(Metabolite)\n", - "df[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df['Metabolite'] = df['SMILES'].apply(Metabolite)\n", + "df['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", - "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", + "node_encoder = AtomFeatureEncoder(feature_list=['symbol', 'num_hydrogen', 'ring_type'])\n", + "bond_encoder = 
BondFeatureEncoder(feature_list=['bond_type', 'ring_type'])\n", "covariate_encoder = CovariateFeatureEncoder(\n", " feature_list=[\n", - " \"collision_energy\",\n", - " \"molecular_weight\",\n", - " \"precursor_mode\",\n", - " \"instrument\",\n", - " \"element_composition\",\n", + " 'collision_energy',\n", + " 'molecular_weight',\n", + " 'precursor_mode',\n", + " 'instrument',\n", + " 'element_composition',\n", " ],\n", " sets_overwrite=overwrite_setup_features,\n", ")\n", "rt_encoder = CovariateFeatureEncoder(\n", " feature_list=[\n", - " \"molecular_weight\",\n", - " \"precursor_mode\",\n", - " \"instrument\",\n", - " \"element_composition\",\n", + " 'molecular_weight',\n", + " 'precursor_mode',\n", + " 'instrument',\n", + " 'element_composition',\n", " ],\n", " sets_overwrite=overwrite_setup_features,\n", ")\n", "\n", - "covariate_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", - "covariate_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", - "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", + "covariate_encoder.normalize_features['collision_energy']['max'] = CE_upper_limit\n", + "covariate_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", + "rt_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", "\n", - "df[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]), axis=1)\n", + "df['Metabolite'].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", + "df.apply(lambda x: x['Metabolite'].set_id(x['group_id']), axis=1)\n", "\n", "# df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", "df.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], covariate_encoder, rt_encoder),\n", " axis=1,\n", ")\n", - "_ = df.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)" + "_ = df.apply(lambda x: x['Metabolite'].set_loss_weight(x['loss_weight']), axis=1)" ] }, { @@ -208,7 +209,7 @@ "from fiora.MOL.MetaboliteIndex import MetaboliteIndex\n", "\n", "mindex: MetaboliteIndex = MetaboliteIndex()\n", - "mindex.index_metabolites(df[\"Metabolite\"])" + "mindex.index_metabolites(df['Metabolite'])" ] }, { @@ -227,9 +228,9 @@ "source": [ "mindex.create_fragmentation_trees()\n", "list_of_mismatched_ids = mindex.add_fragmentation_trees_to_metabolite_list(\n", - " df[\"Metabolite\"], graph_mismatch_policy=\"recompute\"\n", + " df['Metabolite'], graph_mismatch_policy='recompute'\n", ")\n", - "print(f\"Total number of recomputed trees: {len(list_of_mismatched_ids)}\")" + "print(f'Total number of recomputed trees: {len(list_of_mismatched_ids)}')" ] }, { @@ -240,8 +241,8 @@ "source": [ "# df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", "_ = df.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=x['ppm_peak_tolerance']\n", " ),\n", " axis=1,\n", ")" @@ -293,17 +294,17 @@ "from fiora.MOL.MetaboliteDatasetStatistics import MetaboliteDatasetStatistics\n", "\n", "ORDERED_ELEMENT_LIST = [\n", - " \"C\",\n", - " \"H\",\n", 
- " \"O\",\n", - " \"N\",\n", - " \"F\",\n", - " \"Cl\",\n", - " \"Br\",\n", - " \"I\",\n", - " \"P\",\n", - " \"S\",\n", - " \"Si\",\n", + " 'C',\n", + " 'H',\n", + " 'O',\n", + " 'N',\n", + " 'F',\n", + " 'Cl',\n", + " 'Br',\n", + " 'I',\n", + " 'P',\n", + " 'S',\n", + " 'Si',\n", "] # same as in constants.py, but different order\n", "stats = MetaboliteDatasetStatistics(df)\n", "stats.generate_molecular_statistics()" @@ -346,6 +347,7 @@ "source": [ "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", + "\n", "from fiora.visualization.define_colors import (\n", " ELEMENT_COLORS,\n", " set_light_theme,\n", @@ -354,18 +356,18 @@ "set_light_theme()\n", "\n", "# Extract total counts for plotting\n", - "total_counts = stats.get_statistics()[\"Molecular Summary\"][\"Total Counts\"]\n", + "total_counts = stats.get_statistics()['Molecular Summary']['Total Counts']\n", "\n", "# Drop Hydrogen completely\n", "filtered_counts = {\n", - " element: count for element, count in total_counts.items() if element != \"H\"\n", + " element: count for element, count in total_counts.items() if element != 'H'\n", "}\n", "\n", "# Define rare elements (everything except C, O, N)\n", "rare_elements = {\n", " element: count\n", " for element, count in filtered_counts.items()\n", - " if element not in [\"C\", \"O\", \"N\"]\n", + " if element not in ['C', 'O', 'N']\n", "}\n", "\n", "# Create the main plot (all elements including rare ones)\n", @@ -375,11 +377,11 @@ " y=list(filtered_counts.values()),\n", " ax=ax_main,\n", " palette=[ELEMENT_COLORS[element] for element in filtered_counts.keys()],\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", ")\n", - "ax_main.set_title(\"Element Composition\")\n", - "ax_main.set_xlabel(\"Element\")\n", - "ax_main.set_ylabel(\"Total Count\")\n", + "ax_main.set_title('Element Composition')\n", + "ax_main.set_xlabel('Element')\n", + "ax_main.set_ylabel('Total Count')\n", "\n", "# Create the zoomed-in plot for rare elements\n", "ax_zoom_loc = [0.5, 0.4, 0.65, 0.45] # [x, y, width, height] in relative coordinates\n", @@ -389,11 +391,11 @@ " y=list(rare_elements.values()),\n", " ax=ax_zoom,\n", " palette=[ELEMENT_COLORS[element] for element in rare_elements.keys()],\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", ")\n", - "ax_zoom.set_title(\"Rare Elements\")\n", - "ax_zoom.set_xlabel(\"\")\n", - "ax_zoom.set_ylabel(\"\")\n", + "ax_zoom.set_title('Rare Elements')\n", + "ax_zoom.set_xlabel('')\n", + "ax_zoom.set_ylabel('')\n", "\n", "# # Get the position of Fluorine (F) in the main plot\n", "# f_index_main = list(filtered_counts.keys()).index(\"F\")\n", @@ -418,12 +420,12 @@ "\n", "\n", "# Get the position of Fluorine (F) in the main plot\n", - "f_index_main = list(filtered_counts.keys()).index(\"F\")\n", - "f_count_main = filtered_counts[\"F\"]\n", + "f_index_main = list(filtered_counts.keys()).index('F')\n", + "f_count_main = filtered_counts['F']\n", "\n", "# Get the position of Fluorine (F) in the zoomed-in subplot\n", - "f_index_zoom = list(rare_elements.keys()).index(\"F\")\n", - "f_count_zoom = rare_elements[\"F\"]\n", + "f_index_zoom = list(rare_elements.keys()).index('F')\n", + "f_count_zoom = rare_elements['F']\n", "\n", "# Calculate the relative position of Fluorine (F) in the subplot\n", "subplot_x_start, subplot_y_start, subplot_width, subplot_height = ax_zoom_loc\n", @@ -440,7 +442,7 @@ "\n", "# Add an arrow from F in the main plot to F in the subplot\n", "ax_main.annotate(\n", - " \"\",\n", + " '',\n", " xy=(\n", " subplot_x_pos,\n", " 
subplot_y_pos,\n", @@ -449,7 +451,7 @@ " f_index_main,\n", " f_count_main + 7000,\n", " ), # Arrow start point (Fluorine bar in main plot)\n", - " arrowprops=dict(facecolor=\"black\", edgecolor=\"black\", arrowstyle=\"->\"),\n", + " arrowprops=dict(facecolor='black', edgecolor='black', arrowstyle='->'),\n", ")\n", "\n", "# Adjust layout and show the plot\n", @@ -475,16 +477,16 @@ ], "source": [ "# light gray html\n", - "ELEMENT_COLORS[\"ANY_RARE\"] = \"#D3D3D3\" # Light gray for ANY_RARE\n", + "ELEMENT_COLORS['ANY_RARE'] = '#D3D3D3' # Light gray for ANY_RARE\n", "\n", "# Extract presence probabilities for rare elements\n", - "presence_probabilities = stats.get_statistics()[\"Molecular Summary\"][\n", - " \"Presence Probabilities\"\n", + "presence_probabilities = stats.get_statistics()['Molecular Summary'][\n", + " 'Presence Probabilities'\n", "]\n", "rare_element_probabilities = {\n", " element: prob\n", " for element, prob in presence_probabilities.items()\n", - " if element not in [\"C\", \"O\", \"N\", \"H\"]\n", + " if element not in ['C', 'O', 'N', 'H']\n", "}\n", "\n", "# Create the probability plot\n", @@ -495,20 +497,20 @@ " hue=list(rare_element_probabilities.keys()),\n", " ax=ax_prob,\n", " palette=[ELEMENT_COLORS[element] for element in rare_element_probabilities.keys()],\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " legend=False,\n", ")\n", "\n", "# Apply hatching manually for ANY_RARE\n", "for bar, element in zip(bars.patches, rare_element_probabilities.keys()):\n", - " if element == \"ANY_RARE\":\n", - " bar.set_hatch(\"///\") # Apply hatching for ANY_RARE\n", + " if element == 'ANY_RARE':\n", + " bar.set_hatch('///') # Apply hatching for ANY_RARE\n", "\n", "\n", - "ax_prob.set_title(\"Probability of Rare Element Occurrence\")\n", + "ax_prob.set_title('Probability of Rare Element Occurrence')\n", "ax_prob.set_ylim(0, 0.65) # Set y-axis limits to [0, 1] for probabilities\n", - "ax_prob.set_xlabel(\"Element\")\n", - "ax_prob.set_ylabel(\"Probability\")\n", + "ax_prob.set_xlabel('Element')\n", + "ax_prob.set_ylabel('Probability')\n", "\n", "# Adjust layout and show the plot\n", "plt.tight_layout()\n", @@ -642,13 +644,13 @@ "source": [ "# Find a large molecule with S in the structure\n", "large_molecule_with_s = df[\n", - " df[\"Metabolite\"].apply(lambda x: \"S\" in x.node_elements and x.ExactMolWeight > 500)\n", + " df['Metabolite'].apply(lambda x: 'S' in x.node_elements and x.ExactMolWeight > 500)\n", "].iloc[0]\n", "very_large_molecule_with_s = df[\n", - " df[\"Metabolite\"].apply(lambda x: \"S\" in x.node_elements and x.ExactMolWeight > 900)\n", + " df['Metabolite'].apply(lambda x: 'S' in x.node_elements and x.ExactMolWeight > 900)\n", "].iloc[2]\n", "\n", - "large_molecule_with_s[\"Metabolite\"].draw(high_res=True)" + "large_molecule_with_s['Metabolite'].draw(high_res=True)" ] }, { @@ -858,7 +860,7 @@ } ], "source": [ - "very_large_molecule_with_s[\"Metabolite\"].draw(high_res=True)" + "very_large_molecule_with_s['Metabolite'].draw(high_res=True)" ] }, { @@ -879,11 +881,11 @@ "\n", "if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", - " dev = \"cuda:0\"\n", + " dev = 'cuda:0'\n", "else:\n", - " dev = \"cpu\"\n", + " dev = 'cpu'\n", "\n", - "print(f\"Running on device: {dev}\")" + "print(f'Running on device: {dev}')" ] }, { @@ -892,20 +894,19 @@ "metadata": {}, "outputs": [], "source": [ - "MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v0.1.1_OS_depth12_June25.pt\"\n", + "MODEL_PATH = 
f'{home}/data/metabolites/pretrained_models/v0.1.1_OS_depth12_June25.pt'\n", "OLD_MODEL_PATH = (\n", - " f\"{home}/data/metabolites/pretrained_models/v0.1.1_OS_depth5_June25_ls1.pt\"\n", + " f'{home}/data/metabolites/pretrained_models/v0.1.1_OS_depth5_June25_ls1.pt'\n", ")\n", "\n", "from fiora.GNN.FioraModel import FioraModel\n", "from fiora.MS.SimulationFramework import SimulationFramework\n", "\n", - "\n", "try:\n", " model = FioraModel.load_from_state_dict(MODEL_PATH)\n", " old_model = FioraModel.load_from_state_dict(OLD_MODEL_PATH)\n", "except:\n", - " raise NameError(\"Error: Failed loading from state dict.\")" + " raise NameError('Error: Failed loading from state dict.')" ] }, { @@ -1001,20 +1002,19 @@ "source": [ "from fiora.visualization.spectrum_visualizer import plot_spectrum\n", "\n", - "\n", "large_stats = fiora.simulate_and_score(\n", - " large_molecule_with_s[\"Metabolite\"],\n", + " large_molecule_with_s['Metabolite'],\n", " model,\n", - " query_peaks=large_molecule_with_s[\"peaks\"],\n", + " query_peaks=large_molecule_with_s['peaks'],\n", ")\n", "print(\n", - " \"Cosine large (new model)\",\n", - " large_stats[\"spectral_sqrt_cosine\"],\n", - " \" and \",\n", - " large_stats[\"spectral_sqrt_cosine_wo_prec\"],\n", + " 'Cosine large (new model)',\n", + " large_stats['spectral_sqrt_cosine'],\n", + " ' and ',\n", + " large_stats['spectral_sqrt_cosine_wo_prec'],\n", ")\n", "plot_spectrum(\n", - " large_molecule_with_s, {\"peaks\": large_stats[\"sim_peaks\"]}, highlight_matches=True\n", + " large_molecule_with_s, {'peaks': large_stats['sim_peaks']}, highlight_matches=True\n", ")" ] }, @@ -1054,51 +1054,51 @@ "source": [ "copy_large_molecule_with_s = copy.deepcopy(large_molecule_with_s)\n", "\n", - "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", - "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", + "node_encoder = AtomFeatureEncoder(feature_list=['symbol', 'num_hydrogen', 'ring_type'])\n", + "bond_encoder = BondFeatureEncoder(feature_list=['bond_type', 'ring_type'])\n", "covariate_encoder = CovariateFeatureEncoder(\n", " feature_list=[\n", - " \"collision_energy\",\n", - " \"molecular_weight\",\n", - " \"precursor_mode\",\n", - " \"instrument\",\n", + " 'collision_energy',\n", + " 'molecular_weight',\n", + " 'precursor_mode',\n", + " 'instrument',\n", " ],\n", " sets_overwrite=overwrite_setup_features,\n", ")\n", "rt_encoder = CovariateFeatureEncoder(\n", - " feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"],\n", + " feature_list=['molecular_weight', 'precursor_mode', 'instrument'],\n", " sets_overwrite=overwrite_setup_features,\n", ")\n", "\n", - "covariate_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", - "covariate_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", - "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", + "covariate_encoder.normalize_features['collision_energy']['max'] = CE_upper_limit\n", + "covariate_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", + "rt_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", "\n", - "copy_large_molecule_with_s[\"Metabolite\"].compute_graph_attributes(\n", + "copy_large_molecule_with_s['Metabolite'].compute_graph_attributes(\n", " node_encoder, bond_encoder\n", ")\n", "\n", "# df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, 
axis=1)\n", - "copy_large_molecule_with_s[\"Metabolite\"].add_metadata(\n", - " copy_large_molecule_with_s[\"summary\"], covariate_encoder, rt_encoder\n", + "copy_large_molecule_with_s['Metabolite'].add_metadata(\n", + " copy_large_molecule_with_s['summary'], covariate_encoder, rt_encoder\n", ")\n", "# _ = df.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)\n", "\n", "\n", "large_stats_old = fiora.simulate_and_score(\n", - " copy_large_molecule_with_s[\"Metabolite\"],\n", + " copy_large_molecule_with_s['Metabolite'],\n", " old_model,\n", - " query_peaks=copy_large_molecule_with_s[\"peaks\"],\n", + " query_peaks=copy_large_molecule_with_s['peaks'],\n", ")\n", "print(\n", - " f\"Cosine large (old model)\",\n", - " large_stats_old[\"spectral_sqrt_cosine\"],\n", - " \"and\",\n", - " large_stats_old[\"spectral_sqrt_cosine_wo_prec\"],\n", + " 'Cosine large (old model)',\n", + " large_stats_old['spectral_sqrt_cosine'],\n", + " 'and',\n", + " large_stats_old['spectral_sqrt_cosine_wo_prec'],\n", ")\n", "plot_spectrum(\n", " copy_large_molecule_with_s,\n", - " {\"peaks\": large_stats_old[\"sim_peaks\"]},\n", + " {'peaks': large_stats_old['sim_peaks']},\n", " highlight_matches=True,\n", ")" ] @@ -1138,19 +1138,19 @@ ], "source": [ "very_large_stats = fiora.simulate_and_score(\n", - " very_large_molecule_with_s[\"Metabolite\"],\n", + " very_large_molecule_with_s['Metabolite'],\n", " model,\n", - " query_peaks=very_large_molecule_with_s[\"peaks\"],\n", + " query_peaks=very_large_molecule_with_s['peaks'],\n", ")\n", "print(\n", - " \"Cosine very large (new model)\",\n", - " very_large_stats[\"spectral_sqrt_cosine\"],\n", - " \"and\",\n", - " very_large_stats[\"spectral_sqrt_cosine_wo_prec\"],\n", + " 'Cosine very large (new model)',\n", + " very_large_stats['spectral_sqrt_cosine'],\n", + " 'and',\n", + " very_large_stats['spectral_sqrt_cosine_wo_prec'],\n", ")\n", "plot_spectrum(\n", " very_large_molecule_with_s,\n", - " {\"peaks\": very_large_stats[\"sim_peaks\"]},\n", + " {'peaks': very_large_stats['sim_peaks']},\n", " highlight_matches=True,\n", ")" ] @@ -1191,51 +1191,51 @@ "source": [ "copy_very_large_molecule_with_s = copy.deepcopy(very_large_molecule_with_s)\n", "\n", - "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", - "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", + "node_encoder = AtomFeatureEncoder(feature_list=['symbol', 'num_hydrogen', 'ring_type'])\n", + "bond_encoder = BondFeatureEncoder(feature_list=['bond_type', 'ring_type'])\n", "covariate_encoder = CovariateFeatureEncoder(\n", " feature_list=[\n", - " \"collision_energy\",\n", - " \"molecular_weight\",\n", - " \"precursor_mode\",\n", - " \"instrument\",\n", + " 'collision_energy',\n", + " 'molecular_weight',\n", + " 'precursor_mode',\n", + " 'instrument',\n", " ],\n", " sets_overwrite=overwrite_setup_features,\n", ")\n", "rt_encoder = CovariateFeatureEncoder(\n", - " feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"],\n", + " feature_list=['molecular_weight', 'precursor_mode', 'instrument'],\n", " sets_overwrite=overwrite_setup_features,\n", ")\n", "\n", - "covariate_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", - "covariate_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", - "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", + "covariate_encoder.normalize_features['collision_energy']['max'] = 
CE_upper_limit\n", + "covariate_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", + "rt_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", "\n", - "copy_very_large_molecule_with_s[\"Metabolite\"].compute_graph_attributes(\n", + "copy_very_large_molecule_with_s['Metabolite'].compute_graph_attributes(\n", " node_encoder, bond_encoder\n", ")\n", "\n", "# df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", - "copy_very_large_molecule_with_s[\"Metabolite\"].add_metadata(\n", - " copy_very_large_molecule_with_s[\"summary\"], covariate_encoder, rt_encoder\n", + "copy_very_large_molecule_with_s['Metabolite'].add_metadata(\n", + " copy_very_large_molecule_with_s['summary'], covariate_encoder, rt_encoder\n", ")\n", "# _ = df.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)\n", "\n", "\n", "very_large_stats_old = fiora.simulate_and_score(\n", - " copy_very_large_molecule_with_s[\"Metabolite\"],\n", + " copy_very_large_molecule_with_s['Metabolite'],\n", " old_model,\n", - " query_peaks=copy_very_large_molecule_with_s[\"peaks\"],\n", + " query_peaks=copy_very_large_molecule_with_s['peaks'],\n", ")\n", "print(\n", - " f\"Cosine very large (old model)\",\n", - " very_large_stats_old[\"spectral_sqrt_cosine\"],\n", - " \"and \",\n", - " very_large_stats_old[\"spectral_sqrt_cosine_wo_prec\"],\n", + " 'Cosine very large (old model)',\n", + " very_large_stats_old['spectral_sqrt_cosine'],\n", + " 'and ',\n", + " very_large_stats_old['spectral_sqrt_cosine_wo_prec'],\n", ")\n", "plot_spectrum(\n", " copy_very_large_molecule_with_s,\n", - " {\"peaks\": very_large_stats_old[\"sim_peaks\"]},\n", + " {'peaks': very_large_stats_old['sim_peaks']},\n", " highlight_matches=True,\n", ")" ] diff --git a/notebooks/live_predict.ipynb b/notebooks/live_predict.ipynb index 705eb81..16ed7dd 100644 --- a/notebooks/live_predict.ipynb +++ b/notebooks/live_predict.ipynb @@ -15,8 +15,8 @@ ], "source": [ "import sys\n", - "import torch\n", "\n", + "import torch\n", "\n", "seed = 42\n", "# torch.set_default_dtype(torch.float64)\n", @@ -24,39 +24,38 @@ "torch.set_printoptions(precision=2, sci_mode=False)\n", "\n", "\n", - "import pandas as pd\n", - "import numpy as np\n", "import ast\n", "import copy\n", + "\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", "import seaborn as sns\n", "\n", "# Load Modules\n", - "sys.path.append(\"..\")\n", + "sys.path.append('..')\n", "from os.path import expanduser\n", "\n", - "home = expanduser(\"~\")\n", - "from fiora.MOL.constants import DEFAULT_PPM, PPM, DEFAULT_MODES\n", - "from fiora.IO.LibraryLoader import LibraryLoader\n", - "from fiora.MOL.FragmentationTree import FragmentationTree\n", - "import fiora.visualization.spectrum_visualizer as sv\n", + "home = expanduser('~')\n", "import json\n", - "from fiora.GNN.GNNModules import GNNCompiler\n", - "from fiora.MS.SimulationFramework import SimulationFramework\n", - "from fiora.GNN.GNNModules import GNNCompiler\n", - "from fiora.MS.SimulationFramework import SimulationFramework\n", - "from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", - "from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder\n", - "from fiora.GNN.SetupFeatureEncoder import SetupFeatureEncoder\n", - "import fiora.visualization.spectrum_visualizer as sv\n", "\n", - "from sklearn.metrics import r2_score\n", "import scipy\n", + "from fiora.GNN.GNNModules import GNNCompiler\n", + "from 
fiora.GNN.SetupFeatureEncoder import SetupFeatureEncoder\n", "from rdkit import RDLogger\n", + "from sklearn.metrics import r2_score\n", + "\n", + "import fiora.visualization.spectrum_visualizer as sv\n", + "from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", + "from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder\n", + "from fiora.IO.LibraryLoader import LibraryLoader\n", + "from fiora.MOL.constants import DEFAULT_MODES, DEFAULT_PPM, PPM\n", + "from fiora.MOL.FragmentationTree import FragmentationTree\n", + "from fiora.MS.SimulationFramework import SimulationFramework\n", "\n", - "RDLogger.DisableLog(\"rdApp.*\")\n", + "RDLogger.DisableLog('rdApp.*')\n", "\n", - "print(f\"Working with Python {sys.version}\")" + "print(f'Working with Python {sys.version}')" ] }, { @@ -73,22 +72,22 @@ "outputs": [], "source": [ "depth = 6\n", - "MODEL_PATH = f\"../resources/models/fiora_OS_v0.1.0.pt\"\n", + "MODEL_PATH = '../resources/models/fiora_OS_v0.1.0.pt'\n", "\n", "try:\n", " model = GNNCompiler.load(MODEL_PATH)\n", "except:\n", " try:\n", " print(\n", - " f\"Warning: Failed loading the model {MODEL_PATH}. Fall back: Loading the model from state dictionary.\"\n", + " f'Warning: Failed loading the model {MODEL_PATH}. Fall back: Loading the model from state dictionary.'\n", " )\n", " model = GNNCompiler.load_from_state_dict(MODEL_PATH)\n", - " print(\"Model loaded from state dict without further errors.\")\n", + " print('Model loaded from state dict without further errors.')\n", " except:\n", - " raise NameError(\"Error: Failed loading from state dict.\")\n", + " raise NameError('Error: Failed loading from state dict.')\n", "\n", "\n", - "dev = \"cuda:1\"\n", + "dev = 'cuda:1'\n", "\n", "model.eval()\n", "model = model.to(dev)\n", @@ -115,13 +114,13 @@ } ], "source": [ - "if \"version\" in model.model_params.keys():\n", - " print(f\"\\n-----Model-----\")\n", - " print(model.model_params[\"version\"])\n", - " print(f\"---------------\")\n", - "if \"disclaimer\" in model.model_params.keys():\n", - " dis_msg = model.model_params[\"disclaimer\"]\n", - " print(f\"\\nDisclaimer: {dis_msg}\")" + "if 'version' in model.model_params.keys():\n", + " print('\\n-----Model-----')\n", + " print(model.model_params['version'])\n", + " print('---------------')\n", + "if 'disclaimer' in model.model_params.keys():\n", + " dis_msg = model.model_params['disclaimer']\n", + " print(f'\\nDisclaimer: {dis_msg}')" ] }, { @@ -138,23 +137,23 @@ "metadata": {}, "outputs": [], "source": [ - "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", - "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", + "node_encoder = AtomFeatureEncoder(feature_list=['symbol', 'num_hydrogen', 'ring_type'])\n", + "bond_encoder = BondFeatureEncoder(feature_list=['bond_type', 'ring_type'])\n", "model_setup_feature_sets = None\n", - "if \"setup_features_categorical_set\" in model.model_params.keys():\n", - " model_setup_feature_sets = model.model_params[\"setup_features_categorical_set\"]\n", + "if 'setup_features_categorical_set' in model.model_params.keys():\n", + " model_setup_feature_sets = model.model_params['setup_features_categorical_set']\n", "\n", "setup_encoder = SetupFeatureEncoder(\n", " feature_list=[\n", - " \"collision_energy\",\n", - " \"molecular_weight\",\n", - " \"precursor_mode\",\n", - " \"instrument\",\n", + " 'collision_energy',\n", + " 'molecular_weight',\n", + " 'precursor_mode',\n", + " 'instrument',\n", " ],\n", " 
sets_overwrite=model_setup_feature_sets,\n", ")\n", "rt_encoder = SetupFeatureEncoder(\n", - " feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"],\n", + " feature_list=['molecular_weight', 'precursor_mode', 'instrument'],\n", " sets_overwrite=model_setup_feature_sets,\n", ")" ] @@ -181,12 +180,12 @@ "source": [ "from fiora.MOL.Metabolite import Metabolite\n", "\n", - "smiles = \"CC(C)C(CCCN(C)CCC1=CC(=C(C=C1)OC)OC)(C#N)C2=CC(=C(C=C2)OC)OC\"\n", + "smiles = 'CC(C)C(CCCN(C)CCC1=CC(=C(C=C1)OC)OC)(C#N)C2=CC(=C(C=C2)OC)OC'\n", "summary = {\n", - " \"name\": \"Verapamil\",\n", - " \"precursor_mode\": \"[M+H]+\",\n", - " \"collision_energy\": 25.0,\n", - " \"instrument\": \"HCD\",\n", + " 'name': 'Verapamil',\n", + " 'precursor_mode': '[M+H]+',\n", + " 'collision_energy': 25.0,\n", + " 'instrument': 'HCD',\n", "}\n", "\n", "metabolite = Metabolite(smiles)" @@ -249,11 +248,11 @@ ], "source": [ "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12, 3), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(12, 3), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", ")\n", "img = metabolite.draw(ax=axs[0])\n", - "axs[0].set_title(summary[\"name\"])\n", - "sv.plot_spectrum({\"peaks\": pred[\"sim_peaks\"]}, ax=axs[1])\n", + "axs[0].set_title(summary['name'])\n", + "sv.plot_spectrum({'peaks': pred['sim_peaks']}, ax=axs[1])\n", "plt.show()" ] }, @@ -292,20 +291,20 @@ } ], "source": [ - "summary[\"instrument\"] = \"Q-TOF\"\n", + "summary['instrument'] = 'Q-TOF'\n", "metabolite.add_metadata(summary, setup_encoder, rt_encoder)\n", "pred = fiora.simulate_and_score(metabolite, model=model)\n", "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12, 3), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(12, 3), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", ")\n", "img = metabolite.draw(ax=axs[0])\n", - "axs[0].set_title(summary[\"name\"])\n", - "sv.plot_spectrum({\"peaks\": pred[\"sim_peaks\"]}, ax=axs[1])\n", + "axs[0].set_title(summary['name'])\n", + "sv.plot_spectrum({'peaks': pred['sim_peaks']}, ax=axs[1])\n", "plt.show()\n", "\n", - "if model_setup_feature_sets and \"Q-TOF\" not in model_setup_feature_sets[\"instrument\"]:\n", + "if model_setup_feature_sets and 'Q-TOF' not in model_setup_feature_sets['instrument']:\n", " print(\n", - " \"Instrument type: Q-TOF is not a default input of the selected model. The result might not be accurate.\"\n", + " 'Instrument type: Q-TOF is not a default input of the selected model. 
The result might not be accurate.'\n", " )" ] }, @@ -353,15 +352,15 @@ } ], "source": [ - "summary[\"instrument\"] = \"HCD\"\n", - "energy_levels = {\"low\": 15.0, \"moderate\": 25.0, \"high\": 35.0}\n", + "summary['instrument'] = 'HCD'\n", + "energy_levels = {'low': 15.0, 'moderate': 25.0, 'high': 35.0}\n", "for level, ce in energy_levels.items():\n", - " summary[\"collision_energy\"] = ce\n", + " summary['collision_energy'] = ce\n", " metabolite.add_metadata(summary, setup_encoder, rt_encoder)\n", " pred = fiora.simulate_and_score(metabolite, model=model)\n", " fig, ax = plt.subplots(1, 1, figsize=(12, 4))\n", - " sv.plot_spectrum({\"peaks\": pred[\"sim_peaks\"]}, ax=ax)\n", - " plt.title(\"Collision energy level: \" + r\"$\\bf{\" f\"{level}\" + \"}$\" + f\" ({ce} eV)\")\n", + " sv.plot_spectrum({'peaks': pred['sim_peaks']}, ax=ax)\n", + " plt.title('Collision energy level: ' + r'$\\bf{' f'{level}' + '}$' + f' ({ce} eV)')\n", " plt.show()" ] }, @@ -397,15 +396,15 @@ } ], "source": [ - "significant_peak_num = np.argmax(pred[\"sim_peaks\"][\"intensity\"])\n", - "fragment_smiles, ion_mode = pred[\"sim_peaks\"][\"annotation\"][significant_peak_num].split(\n", - " \"//\"\n", + "significant_peak_num = np.argmax(pred['sim_peaks']['intensity'])\n", + "fragment_smiles, ion_mode = pred['sim_peaks']['annotation'][significant_peak_num].split(\n", + " '//'\n", ")\n", "print(\n", - " f\"Most significant (non-precursor fragment) {fragment_smiles} found in ionization mode {ion_mode}.\\nThe hydrogen losses suggest a formation of a double bond somewhere in the structure below.\"\n", + " f'Most significant (non-precursor fragment) {fragment_smiles} found in ionization mode {ion_mode}.\\nThe hydrogen losses suggest a formation of a double bond somewhere in the structure below.'\n", ")\n", "Metabolite(fragment_smiles).draw()\n", - "plt.title(f\"{fragment_smiles} ({ion_mode} ionization)\")\n", + "plt.title(f'{fragment_smiles} ({ion_mode} ionization)')\n", "plt.show()" ] } diff --git a/notebooks/sandbox.ipynb b/notebooks/sandbox.ipynb index 64022e6..b8a1211 100644 --- a/notebooks/sandbox.ipynb +++ b/notebooks/sandbox.ipynb @@ -24,6 +24,7 @@ ], "source": [ "import sys\n", + "\n", "import torch\n", "\n", "seed = 42\n", @@ -32,28 +33,29 @@ "torch.set_printoptions(precision=2, sci_mode=False)\n", "\n", "\n", - "import pandas as pd\n", - "import numpy as np\n", "import ast\n", "import copy\n", "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", "# Load Modules\n", - "sys.path.append(\"..\")\n", + "sys.path.append('..')\n", "from os.path import expanduser\n", "\n", - "home = expanduser(\"~\")\n", - "from fiora.MOL.constants import DEFAULT_PPM, PPM, DEFAULT_MODES\n", - "from fiora.IO.LibraryLoader import LibraryLoader\n", - "from fiora.MOL.FragmentationTree import FragmentationTree\n", - "import fiora.visualization.spectrum_visualizer as sv\n", - "\n", - "from sklearn.metrics import r2_score\n", + "home = expanduser('~')\n", "import scipy\n", "from rdkit import RDLogger\n", + "from sklearn.metrics import r2_score\n", "\n", - "RDLogger.DisableLog(\"rdApp.*\")\n", + "import fiora.visualization.spectrum_visualizer as sv\n", + "from fiora.IO.LibraryLoader import LibraryLoader\n", + "from fiora.MOL.constants import DEFAULT_MODES, DEFAULT_PPM, PPM\n", + "from fiora.MOL.FragmentationTree import FragmentationTree\n", + "\n", + "RDLogger.DisableLog('rdApp.*')\n", "\n", - "print(f\"Working with Python {sys.version}\")" + "print(f'Working with Python {sys.version}')" ] }, { @@ -81,13 +83,13 @@ "source": [ "from 
typing import Literal\n", "\n", - "lib: Literal[\"NIST\", \"MSDIAL\", \"NIST/MSDIAL\", \"MSnLib\"] = \"MSnLib\" # \"MSnLib\"\n", - "print(f\"Preparing {lib} library\")\n", + "lib: Literal['NIST', 'MSDIAL', 'NIST/MSDIAL', 'MSnLib'] = 'MSnLib' # \"MSnLib\"\n", + "print(f'Preparing {lib} library')\n", "\n", "test_run = True # Default: False\n", "if test_run:\n", " print(\n", - " \"+++ This is a test run with a small subset of data points. Results are not representative. +++\"\n", + " '+++ This is a test run with a small subset of data points. Results are not representative. +++'\n", " )" ] }, @@ -99,14 +101,14 @@ "source": [ "# key map to read metadata from pandas DataFrame\n", "metadata_key_map = {\n", - " \"name\": \"Name\",\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"ionization\": \"Ionization\",\n", - " \"precursor_mz\": \"PrecursorMZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", - " \"ccs\": \"CCS\",\n", + " 'name': 'Name',\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'ionization': 'Ionization',\n", + " 'precursor_mz': 'PrecursorMZ',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'RETENTIONTIME',\n", + " 'ccs': 'CCS',\n", "}\n", "\n", "\n", @@ -116,14 +118,14 @@ "\n", "\n", "def load_training_data():\n", - " if \"NIST\" in lib or \"MSDIAL\" in lib:\n", - " data_path: str = f\"{home}/data/metabolites/preprocessed/datasplits_Jan24.csv\"\n", - " elif lib == \"MSnLib\":\n", + " if 'NIST' in lib or 'MSDIAL' in lib:\n", + " data_path: str = f'{home}/data/metabolites/preprocessed/datasplits_Jan24.csv'\n", + " elif lib == 'MSnLib':\n", " data_path: str = (\n", - " f\"{home}/data/metabolites/preprocessed/datasplits_msnlib_Aug24_v3.csv\"\n", + " f'{home}/data/metabolites/preprocessed/datasplits_msnlib_Aug24_v3.csv'\n", " )\n", " else:\n", - " raise NameError(f\"Unknown library selected {lib=}.\")\n", + " raise NameError(f'Unknown library selected {lib=}.')\n", " L = LibraryLoader()\n", " df = L.load_from_csv(data_path)\n", " return df\n", @@ -132,12 +134,12 @@ "df = load_training_data()\n", "\n", "# Restore dictionary values\n", - "dict_columns = [\"peaks\", \"summary\"]\n", + "dict_columns = ['peaks', 'summary']\n", "for col in dict_columns:\n", - " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace(\"nan\", \"None\")))\n", + " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", " # df[col] = df[col].apply(ast.literal_eval)\n", "\n", - "df[\"group_id\"] = df[\"group_id\"].astype(int)" + "df['group_id'] = df['group_id'].astype(int)" ] }, { @@ -157,11 +159,11 @@ "outputs": [], "source": [ "%%capture\n", - "from fiora.MOL.Metabolite import Metabolite\n", - "from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", - "from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder\n", "from fiora.GNN.SetupFeatureEncoder import SetupFeatureEncoder\n", "\n", + "from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", + "from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder\n", + "from fiora.MOL.Metabolite import Metabolite\n", "\n", "CE_upper_limit = 100.0\n", "weight_upper_limit = 1000.0\n", @@ -172,44 +174,44 @@ " # df = df.iloc[5000:20000,:]\n", "\n", "overwrite_setup_features = None\n", - "if lib == \"MSnLib\":\n", + "if lib == 'MSnLib':\n", " overwrite_setup_features = {\n", - " \"instrument\": [\"HCD\"],\n", - " \"precursor_mode\": [\"[M+H]+\", \"[M-H]-\", \"[M]+\", 
\"[M]-\"],\n", + " 'instrument': ['HCD'],\n", + " 'precursor_mode': ['[M+H]+', '[M-H]-', '[M]+', '[M]-'],\n", " }\n", "\n", - "df[\"Metabolite\"] = df[\"SMILES\"].apply(Metabolite)\n", - "df[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df['Metabolite'] = df['SMILES'].apply(Metabolite)\n", + "df['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", - "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", + "node_encoder = AtomFeatureEncoder(feature_list=['symbol', 'num_hydrogen', 'ring_type'])\n", + "bond_encoder = BondFeatureEncoder(feature_list=['bond_type', 'ring_type'])\n", "setup_encoder = SetupFeatureEncoder(\n", " feature_list=[\n", - " \"collision_energy\",\n", - " \"molecular_weight\",\n", - " \"precursor_mode\",\n", - " \"instrument\",\n", + " 'collision_energy',\n", + " 'molecular_weight',\n", + " 'precursor_mode',\n", + " 'instrument',\n", " ],\n", " sets_overwrite=overwrite_setup_features,\n", ")\n", "rt_encoder = SetupFeatureEncoder(\n", - " feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"],\n", + " feature_list=['molecular_weight', 'precursor_mode', 'instrument'],\n", " sets_overwrite=overwrite_setup_features,\n", ")\n", "\n", - "setup_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", - "setup_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", - "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", + "setup_encoder.normalize_features['collision_energy']['max'] = CE_upper_limit\n", + "setup_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", + "rt_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", "\n", - "df[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]), axis=1)\n", + "df['Metabolite'].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", + "df.apply(lambda x: x['Metabolite'].set_id(x['group_id']), axis=1)\n", "\n", "# df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", "df.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", ")\n", - "df.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)" + "df.apply(lambda x: x['Metabolite'].set_loss_weight(x['loss_weight']), axis=1)" ] }, { @@ -219,10 +221,10 @@ "outputs": [], "source": [ "%%capture\n", - "df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "df.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=x['ppm_peak_tolerance']\n", " ),\n", " axis=1,\n", ")" @@ -252,11 +254,10 @@ } ], "source": [ - "from fiora.MOL.constants import *\n", - "\n", "from rdkit import Chem\n", "from rdkit.Chem import Descriptors\n", "\n", + "from fiora.MOL.constants import *\n", "\n", "Descriptors.ExactMolWt(h_2)" ] @@ -288,7 +289,7 @@ 
"metadata": {}, "outputs": [], "source": [ - "path: str = f\"{home}/data/metabolites/preprocessed/rings_msnlib.csv\"\n", + "path: str = f'{home}/data/metabolites/preprocessed/rings_msnlib.csv'\n", "df_rings = pd.read_csv(path)" ] }, @@ -299,29 +300,29 @@ "outputs": [], "source": [ "%%capture\n", - "df_rings[\"Metabolite\"] = df_rings[\"SMILES\"].apply(Metabolite)\n", - "df_rings[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", - "df_rings[\"Metabolite\"].apply(\n", + "df_rings['Metabolite'] = df_rings['SMILES'].apply(Metabolite)\n", + "df_rings['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", + "df_rings['Metabolite'].apply(\n", " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", ")\n", - "df_rings.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]), axis=1)\n", + "df_rings.apply(lambda x: x['Metabolite'].set_id(x['group_id']), axis=1)\n", "\n", - "dict_columns = [\"peaks\", \"summary\"]\n", + "dict_columns = ['peaks', 'summary']\n", "for col in dict_columns:\n", " df_rings[col] = df_rings[col].apply(\n", - " lambda x: ast.literal_eval(x.replace(\"nan\", \"None\"))\n", + " lambda x: ast.literal_eval(x.replace('nan', 'None'))\n", " )\n", "\n", "# df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", "df_rings.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", ")\n", - "df_rings.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)\n", - "df_rings[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df_rings.apply(lambda x: x['Metabolite'].set_loss_weight(x['loss_weight']), axis=1)\n", + "df_rings['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "df_rings.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=x['ppm_peak_tolerance']\n", " ),\n", " axis=1,\n", ")" @@ -590,12 +591,12 @@ ], "source": [ "for i, row in (\n", - " df_rings[df_rings[\"Metabolite\"].apply(lambda x: bool(x.ring_proportion == 1.0))]\n", - " .drop_duplicates(subset=[\"group_id\"])\n", + " df_rings[df_rings['Metabolite'].apply(lambda x: bool(x.ring_proportion == 1.0))]\n", + " .drop_duplicates(subset=['group_id'])\n", " .iterrows()\n", "):\n", " print(i)\n", - " row[\"Metabolite\"].draw(show=True)" + " row['Metabolite'].draw(show=True)" ] }, { @@ -627,7 +628,7 @@ } ], "source": [ - "df_rings[\"Metabolite\"].iloc[16].draw(show=True)" + "df_rings['Metabolite'].iloc[16].draw(show=True)" ] }, { @@ -661,7 +662,7 @@ } ], "source": [ - "candidate_2[\"QUALITY_EXPLAINED_INTENSITY\"]" + "candidate_2['QUALITY_EXPLAINED_INTENSITY']" ] }, { @@ -691,19 +692,19 @@ "import matplotlib.pyplot as plt\n", "\n", "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(14, 4), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(14, 4), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", ")\n", - "img = candidate_1[\"Metabolite\"].draw(ax=axs[0])\n", + "img = candidate_1['Metabolite'].draw(ax=axs[0])\n", "\n", "axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, 
labelleft=False\n", ")\n", "axs[0].set_title(\n", - " \"Name: \"\n", - " + candidate_1[\"NAME\"]\n", - " + \"\\nCollision energy: \"\n", - " + str(candidate_1[\"CE\"])\n", - " + \" eV\"\n", + " 'Name: '\n", + " + candidate_1['NAME']\n", + " + '\\nCollision energy: '\n", + " + str(candidate_1['CE'])\n", + " + ' eV'\n", ")\n", "\n", "axs[1] = sv.plot_spectrum(\n", @@ -711,14 +712,14 @@ " None,\n", " ax=axs[1],\n", " highlight_matches=True,\n", - " mz_matches=candidate_1[\"Metabolite\"].peak_matches.keys(),\n", + " mz_matches=candidate_1['Metabolite'].peak_matches.keys(),\n", ")\n", "plt.show()\n", "print(\n", " [\n", - " (candidate_1[\"peaks\"][\"mz\"][i], candidate_1[\"peaks\"][\"intensity\"][i])\n", - " for i in range(len((candidate_1[\"peaks\"][\"mz\"])))\n", - " if candidate_1[\"peaks\"][\"intensity\"][i] > 20\n", + " (candidate_1['peaks']['mz'][i], candidate_1['peaks']['intensity'][i])\n", + " for i in range(len(candidate_1['peaks']['mz']))\n", + " if candidate_1['peaks']['intensity'][i] > 20\n", " ]\n", ")" ] @@ -750,19 +751,19 @@ "import matplotlib.pyplot as plt\n", "\n", "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(14, 4), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(14, 4), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", ")\n", - "img = candidate_2[\"Metabolite\"].draw(ax=axs[0])\n", + "img = candidate_2['Metabolite'].draw(ax=axs[0])\n", "\n", "axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", ")\n", "axs[0].set_title(\n", - " \"Name: \"\n", - " + candidate_2[\"NAME\"]\n", - " + \"\\nCollision energy: \"\n", - " + str(candidate_2[\"CE\"])\n", - " + \" eV\"\n", + " 'Name: '\n", + " + candidate_2['NAME']\n", + " + '\\nCollision energy: '\n", + " + str(candidate_2['CE'])\n", + " + ' eV'\n", ")\n", "\n", "axs[1] = sv.plot_spectrum(\n", @@ -770,14 +771,14 @@ " None,\n", " ax=axs[1],\n", " highlight_matches=True,\n", - " mz_matches=candidate_2[\"Metabolite\"].peak_matches.keys(),\n", + " mz_matches=candidate_2['Metabolite'].peak_matches.keys(),\n", ")\n", "plt.show()\n", "print(\n", " [\n", - " (candidate_2[\"peaks\"][\"mz\"][i], candidate_2[\"peaks\"][\"intensity\"][i])\n", - " for i in range(len((candidate_2[\"peaks\"][\"mz\"])))\n", - " if candidate_2[\"peaks\"][\"intensity\"][i] > 20\n", + " (candidate_2['peaks']['mz'][i], candidate_2['peaks']['intensity'][i])\n", + " for i in range(len(candidate_2['peaks']['mz']))\n", + " if candidate_2['peaks']['intensity'][i] > 20\n", " ]\n", ")" ] @@ -788,9 +789,10 @@ "metadata": {}, "outputs": [], "source": [ - "from periodictable import elements\n", "from typing import Dict\n", "\n", + "from periodictable import elements\n", + "\n", "# Create a dictionary for exact weights of all elements\n", "element_weights = {el.symbol: el.mass for el in elements if el.mass is not None}\n", "\n", @@ -819,7 +821,7 @@ } ], "source": [ - "candidate_2[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=None)" + "candidate_2['Metabolite'].get_theoretical_precursor_mz(ion_type=None)" ] }, { @@ -838,8 +840,8 @@ "source": [ "print(\n", " get_exact_mass(\n", - " {\"C\": 1, \"H\": 1, \"N\": 1},\n", - " candidate_2[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=None),\n", + " {'C': 1, 'H': 1, 'N': 1},\n", + " candidate_2['Metabolite'].get_theoretical_precursor_mz(ion_type=None),\n", " )\n", ")" ] @@ -860,8 +862,8 @@ "source": [ "print(\n", " get_exact_mass(\n", - " {\"C\": 3, \"H\": 
3, \"N\": 1},\n", - " candidate_2[\"Metabolite\"].get_theoretical_precursor_mz(ion_type=None),\n", + " {'C': 3, 'H': 3, 'N': 1},\n", + " candidate_2['Metabolite'].get_theoretical_precursor_mz(ion_type=None),\n", " )\n", ")" ] @@ -880,7 +882,7 @@ } ], "source": [ - "print(get_exact_mass({\"C\": 3, \"H\": 0, \"N\": 2}))" + "print(get_exact_mass({'C': 3, 'H': 0, 'N': 2}))" ] }, { @@ -906,7 +908,7 @@ } ], "source": [ - "df_rings[df_rings[\"NAME\"] == \"Purine\"][\"CE\"]" + "df_rings[df_rings['NAME'] == 'Purine']['CE']" ] }, { @@ -1019,18 +1021,18 @@ "metadata": {}, "outputs": [], "source": [ - "casmi16_path = f\"{home}/data/metabolites/CASMI_2016/casmi16_withCCS.csv\"\n", - "casmi22_path = f\"{home}/data/metabolites/CASMI_2022/casmi22_withCCS.csv\"\n", + "casmi16_path = f'{home}/data/metabolites/CASMI_2016/casmi16_withCCS.csv'\n", + "casmi22_path = f'{home}/data/metabolites/CASMI_2022/casmi22_withCCS.csv'\n", "\n", "df_cas = pd.read_csv(casmi16_path, index_col=[0], low_memory=False)\n", "df_cas22 = pd.read_csv(casmi22_path, index_col=[0], low_memory=False)\n", "\n", "# Restore dictionary values\n", - "dict_columns = [\"peaks\", \"Candidates\"]\n", + "dict_columns = ['peaks', 'Candidates']\n", "for col in dict_columns:\n", " df_cas[col] = df_cas[col].apply(ast.literal_eval)\n", "\n", - "df_cas22[\"peaks\"] = df_cas22[\"peaks\"].apply(ast.literal_eval)" + "df_cas22['peaks'] = df_cas22['peaks'].apply(ast.literal_eval)" ] }, { @@ -1042,36 +1044,36 @@ "%%capture\n", "from fiora.MOL.collision_energy import NCE_to_eV\n", "\n", - "df_cas[\"RETENTIONTIME\"] = df_cas[\"RTINSECONDS\"] / 60.0\n", - "df_cas[\"Metabolite\"] = df_cas[\"SMILES\"].apply(Metabolite)\n", - "df_cas[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df_cas['RETENTIONTIME'] = df_cas['RTINSECONDS'] / 60.0\n", + "df_cas['Metabolite'] = df_cas['SMILES'].apply(Metabolite)\n", + "df_cas['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cas[\"Metabolite\"].apply(\n", + "df_cas['Metabolite'].apply(\n", " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", ")\n", - "df_cas[\"CE\"] = 20.0 # actually stepped 20/35/50\n", - "df_cas[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", + "df_cas['CE'] = 20.0 # actually stepped 20/35/50\n", + "df_cas['Instrument_type'] = 'HCD' # CHECK if correct Orbitrap\n", "\n", "metadata_key_map16 = {\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"PRECURSOR_MZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'precursor_mz': 'PRECURSOR_MZ',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'RETENTIONTIME',\n", "}\n", "\n", - "df_cas[\"summary\"] = df_cas.apply(\n", + "df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", ")\n", "df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder), axis=1\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder), axis=1\n", ")\n", "\n", "# Fragmentation\n", - "df_cas[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df_cas['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " 
lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=100 * PPM\n", " ),\n", " axis=1,\n", ") # Optional: use mz_cut instead\n", @@ -1080,37 +1082,37 @@ "# CASMI 22\n", "#\n", "\n", - "df_cas22[\"Metabolite\"] = df_cas22[\"SMILES\"].apply(Metabolite)\n", - "df_cas22[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df_cas22['Metabolite'] = df_cas22['SMILES'].apply(Metabolite)\n", + "df_cas22['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cas22[\"Metabolite\"].apply(\n", + "df_cas22['Metabolite'].apply(\n", " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", ")\n", - "df_cas22[\"CE\"] = df_cas22.apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"precursor_mz\"]), axis=1\n", + "df_cas22['CE'] = df_cas22.apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['precursor_mz']), axis=1\n", ")\n", "\n", "metadata_key_map22 = {\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"precursor_mz\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"ChallengeRT\",\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'precursor_mz': 'precursor_mz',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'ChallengeRT',\n", "}\n", "\n", - "df_cas22[\"summary\"] = df_cas22.apply(\n", + "df_cas22['summary'] = df_cas22.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map22.items()}, axis=1\n", ")\n", "df_cas22.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", ")\n", "\n", "# Fragmentation\n", - "df_cas22[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df_cas22['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "df_cas22.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=100 * PPM\n", " ),\n", " axis=1,\n", ") # Optional: use mz_cut instead\n", @@ -1140,16 +1142,17 @@ } ], "source": [ - "from fiora.GNN.Trainer import Trainer\n", "import torch_geometric as geom\n", "\n", + "from fiora.GNN.Trainer import Trainer\n", + "\n", "if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", - " dev = \"cuda:1\"\n", + " dev = 'cuda:1'\n", "else:\n", - " dev = \"cpu\"\n", + " dev = 'cpu'\n", "\n", - "print(f\"Running on device: {dev}\")" + "print(f'Running on device: {dev}')" ] }, { @@ -1178,10 +1181,10 @@ } ], "source": [ - "print(df.groupby(\"dataset\")[\"group_id\"].unique().apply(len))\n", + "print(df.groupby('dataset')['group_id'].unique().apply(len))\n", "\n", - "df_test = df[df[\"dataset\"] == \"test\"]\n", - "df_train = df[df[\"dataset\"].isin([\"training\", \"validation\"])]" + "df_test = df[df['dataset'] == 'test']\n", + "df_train = df[df['dataset'].isin(['training', 'validation'])]" ] }, { @@ -1198,8 +1201,8 @@ } ], "source": [ - "geo_data = df_train[\"Metabolite\"].apply(lambda x: x.as_geometric_data().to(dev)).values\n", - "print(f\"Prepared training/validation with {len(geo_data)} data points\")" + "geo_data = df_train['Metabolite'].apply(lambda x: x.as_geometric_data().to(dev)).values\n", + "print(f'Prepared training/validation with 
{len(geo_data)} data points')" ] }, { @@ -1217,38 +1220,38 @@ "outputs": [], "source": [ "model_params = {\n", - " \"param_tag\": \"default\",\n", - " \"gnn_type\": \"RGCNConv\",\n", - " \"depth\": 6,\n", - " \"hidden_dimension\": 300,\n", - " \"dense_layers\": 2,\n", - " \"embedding_aggregation\": \"concat\",\n", - " \"embedding_dimension\": 300,\n", - " \"input_dropout\": 0.2,\n", - " \"latent_dropout\": 0.1,\n", - " \"node_feature_layout\": node_encoder.feature_numbers,\n", - " \"edge_feature_layout\": bond_encoder.feature_numbers,\n", - " \"static_feature_dimension\": geo_data[0][\"static_edge_features\"].shape[1],\n", - " \"static_rt_feature_dimension\": geo_data[0][\"static_rt_features\"].shape[1],\n", - " \"output_dimension\": len(DEFAULT_MODES) * 2, # per edge\n", + " 'param_tag': 'default',\n", + " 'gnn_type': 'RGCNConv',\n", + " 'depth': 6,\n", + " 'hidden_dimension': 300,\n", + " 'dense_layers': 2,\n", + " 'embedding_aggregation': 'concat',\n", + " 'embedding_dimension': 300,\n", + " 'input_dropout': 0.2,\n", + " 'latent_dropout': 0.1,\n", + " 'node_feature_layout': node_encoder.feature_numbers,\n", + " 'edge_feature_layout': bond_encoder.feature_numbers,\n", + " 'static_feature_dimension': geo_data[0]['static_edge_features'].shape[1],\n", + " 'static_rt_feature_dimension': geo_data[0]['static_rt_features'].shape[1],\n", + " 'output_dimension': len(DEFAULT_MODES) * 2, # per edge\n", " # Keep track of encoded features\n", - " \"atom_features\": node_encoder.feature_list,\n", - " \"atom_features\": bond_encoder.feature_list,\n", - " \"setup_features\": setup_encoder.feature_list,\n", - " \"setup_features_categorical_set\": setup_encoder.categorical_sets,\n", - " \"rt_features\": rt_encoder.feature_list,\n", + " 'atom_features': node_encoder.feature_list,\n", + " 'atom_features': bond_encoder.feature_list,\n", + " 'setup_features': setup_encoder.feature_list,\n", + " 'setup_features_categorical_set': setup_encoder.categorical_sets,\n", + " 'rt_features': rt_encoder.feature_list,\n", " # Set default flags (May be overwritten below)\n", - " \"rt_supported\": False,\n", - " \"ccs_supported\": False,\n", - " \"version\": \"x.x.x\",\n", + " 'rt_supported': False,\n", + " 'ccs_supported': False,\n", + " 'version': 'x.x.x',\n", "}\n", "training_params = {\n", - " \"epochs\": 200 if not test_run else 10,\n", - " \"batch_size\": 256,\n", + " 'epochs': 200 if not test_run else 10,\n", + " 'batch_size': 256,\n", " #'train_val_split': 0.90,\n", - " \"learning_rate\": 0.0004, # 0.00001 currently for wMAE # Default for wMSE is 0.0004, #0.001,\n", - " \"with_RT\": False, # Turn off RT/CCS for initial trainings round\n", - " \"with_CCS\": False,\n", + " 'learning_rate': 0.0004, # 0.00001 currently for wMAE # Default for wMSE is 0.0004, #0.001,\n", + " 'with_RT': False, # Turn off RT/CCS for initial trainings round\n", + " 'with_CCS': False,\n", "}" ] }, @@ -1267,20 +1270,21 @@ "outputs": [], "source": [ "from fiora.GNN.GNNModules import GNNCompiler\n", + "\n", "from fiora.GNN.Losses import (\n", - " WeightedMSELoss,\n", - " WeightedMSEMetric,\n", " WeightedMAELoss,\n", " WeightedMAEMetric,\n", + " WeightedMSELoss,\n", + " WeightedMSEMetric,\n", ")\n", "from fiora.MS.SimulationFramework import SimulationFramework\n", "\n", "fiora = SimulationFramework(None, dev=dev, with_RT=True, with_CCS=True)\n", "# fiora = SimulationFramework(None, dev=dev, with_RT=training_params[\"with_RT\"], with_CCS=training_params[\"with_CCS\"])\n", - "np.seterr(invalid=\"ignore\")\n", - "tag = \"training\"\n", + 
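Both the removed and the added lines of the model_params hunk above repeat the key "atom_features": once for node_encoder.feature_list and once for bond_encoder.feature_list. A Python dict literal accepts duplicate keys without raising and keeps only the last value, so the atom-level entry is silently lost; the second key was presumably meant to be "bond_features". A minimal, self-contained demonstration of the failure mode (the values are illustrative):

    # Duplicate keys in a dict literal do not raise; the last value wins.
    params = {
        "atom_features": ["symbol", "num_hydrogen", "ring_type"],
        "atom_features": ["bond_type", "ring_type"],  # silently overwrites the entry above
    }
    print(params)  # {'atom_features': ['bond_type', 'ring_type']}

Renaming the second key (for example to "bond_features") would preserve both feature lists; as written, only the bond encoder's list survives in the saved model parameters.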
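The imports above pull WeightedMSELoss and WeightedMSEMetric (plus their MAE counterparts) from fiora.GNN.Losses, and the trainer below wires the MSE pair in as loss function and validation metric, with each spectrum's loss_weight column attached to its Metabolite earlier via set_loss_weight. The implementations themselves are not part of this patch, so the sketch below only illustrates the general idea of a sample-weighted MSE; the function name and signature are assumptions, not fiora's API:

    import torch

    def weighted_mse(pred: torch.Tensor, target: torch.Tensor,
                     weight: torch.Tensor) -> torch.Tensor:
        # Scale each squared residual by its per-sample weight, then average.
        return (weight * (pred - target) ** 2).mean()

    pred = torch.tensor([0.2, 0.7, 0.1])    # e.g. predicted fragment intensities
    target = torch.tensor([0.0, 1.0, 0.0])  # observed (normalized) intensities
    weight = torch.tensor([1.0, 2.0, 1.0])  # e.g. a per-spectrum loss_weight
    print(weighted_mse(pred, target, weight))  # tensor(0.0767)

A common motivation for such per-sample weights is to keep over-represented compounds from dominating the gradient, though the exact role of loss_weight is not spelled out in this patch.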
"np.seterr(invalid='ignore')\n", + "tag = 'training'\n", "val_interval = 1\n", - "metric_dict = {\"mse\": WeightedMSEMetric} # WeightedMSEMetric\n", + "metric_dict = {'mse': WeightedMSEMetric} # WeightedMSEMetric\n", "loss_fn = WeightedMSELoss() # WeightedMSELoss()\n", "all_together = False\n", "\n", @@ -1296,15 +1300,15 @@ " else:\n", " model = GNNCompiler(model_params).to(dev)\n", "\n", - " y_label = \"compiled_probsSQRT\" # y_label = 'compiled_probsALL'\n", + " y_label = 'compiled_probsSQRT' # y_label = 'compiled_probsALL'\n", " optimizer = torch.optim.Adam(\n", - " model.parameters(), lr=training_params[\"learning_rate\"]\n", + " model.parameters(), lr=training_params['learning_rate']\n", " )\n", " if all_together:\n", " trainer = Trainer(\n", " geo_data,\n", " y_tag=y_label,\n", - " problem_type=\"regression\",\n", + " problem_type='regression',\n", " only_training=True,\n", " metric_dict=metric_dict,\n", " split_by_group=True,\n", @@ -1314,13 +1318,13 @@ " scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)\n", " else:\n", " train_keys, val_keys = (\n", - " df[df[\"dataset\"] == \"training\"][\"group_id\"].unique(),\n", - " df[df[\"dataset\"] == \"validation\"][\"group_id\"].unique(),\n", + " df[df['dataset'] == 'training']['group_id'].unique(),\n", + " df[df['dataset'] == 'validation']['group_id'].unique(),\n", " )\n", " trainer = Trainer(\n", " geo_data,\n", " y_tag=y_label,\n", - " problem_type=\"regression\",\n", + " problem_type='regression',\n", " train_keys=train_keys,\n", " val_keys=val_keys,\n", " metric_dict=metric_dict,\n", @@ -1329,7 +1333,7 @@ " device=dev,\n", " )\n", " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", - " optimizer, patience=8, factor=0.5, mode=\"min\", verbose=True\n", + " optimizer, patience=8, factor=0.5, mode='min', verbose=True\n", " )\n", "\n", " checkpoints = trainer.train(\n", @@ -1337,11 +1341,11 @@ " optimizer,\n", " loss_fn,\n", " scheduler=scheduler,\n", - " batch_size=training_params[\"batch_size\"],\n", - " epochs=training_params[\"epochs\"],\n", + " batch_size=training_params['batch_size'],\n", + " epochs=training_params['epochs'],\n", " val_every_n_epochs=1,\n", - " with_CCS=training_params[\"with_CCS\"],\n", - " with_RT=training_params[\"with_RT\"],\n", + " with_CCS=training_params['with_CCS'],\n", + " with_RT=training_params['with_RT'],\n", " masked_validation=False,\n", " tag=tag,\n", " ) # , mask_name=\"compiled_validation_maskALL\")\n", @@ -1353,7 +1357,7 @@ " return fiora.simulate_all(DF, model)\n", "\n", "\n", - "def test_model(model, DF, score=\"spectral_sqrt_cosine\", return_df=False):\n", + "def test_model(model, DF, score='spectral_sqrt_cosine', return_df=False):\n", " dft = simulate_all(model, DF)\n", "\n", " if return_df:\n", @@ -1374,109 +1378,108 @@ "metadata": {}, "outputs": [], "source": [ - "from fiora.MOL.collision_energy import NCE_to_eV\n", + "from fiora.MS.ms_utility import merge_annotated_spectrum\n", "from fiora.MS.spectral_scores import (\n", + " reweighted_dot,\n", " spectral_cosine,\n", " spectral_reflection_cosine,\n", - " reweighted_dot,\n", ")\n", - "from fiora.MS.ms_utility import merge_annotated_spectrum\n", "\n", "\n", - "def test_cas16(model, df_cas=df_cas, score=\"merged_sqrt_cosine\", return_df=False):\n", + "def test_cas16(model, df_cas=df_cas, score='merged_sqrt_cosine', return_df=False):\n", "\n", - " df_cas[\"NCE\"] = 20.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], 
x[\"PRECURSOR_MZ\"]), axis=1\n", + " df_cas['NCE'] = 20.0 # actually stepped NCE 20/35/50\n", + " df_cas['CE'] = df_cas[['NCE', 'PRECURSOR_MZ']].apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['PRECURSOR_MZ']), axis=1\n", " )\n", - " df_cas[\"step1_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(\n", + " df_cas['step1_CE'] = df_cas['CE']\n", + " df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", " )\n", " df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", " )\n", - " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_20\")\n", + " df_cas = fiora.simulate_all(df_cas, model, suffix='_20')\n", "\n", - " df_cas[\"NCE\"] = 35.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " df_cas['NCE'] = 35.0 # actually stepped NCE 20/35/50\n", + " df_cas['CE'] = df_cas[['NCE', 'PRECURSOR_MZ']].apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['PRECURSOR_MZ']), axis=1\n", " )\n", - " df_cas[\"step2_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(\n", + " df_cas['step2_CE'] = df_cas['CE']\n", + " df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", " )\n", " df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", " )\n", - " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_35\")\n", + " df_cas = fiora.simulate_all(df_cas, model, suffix='_35')\n", "\n", - " df_cas[\"NCE\"] = 50.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " df_cas['NCE'] = 50.0 # actually stepped NCE 20/35/50\n", + " df_cas['CE'] = df_cas[['NCE', 'PRECURSOR_MZ']].apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['PRECURSOR_MZ']), axis=1\n", " )\n", - " df_cas[\"step3_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(\n", + " df_cas['step3_CE'] = df_cas['CE']\n", + " df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", " )\n", " df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", " )\n", - " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_50\")\n", + " df_cas = fiora.simulate_all(df_cas, model, suffix='_50')\n", "\n", - " df_cas[\"avg_CE\"] = (\n", - " df_cas[\"step1_CE\"] + df_cas[\"step2_CE\"] + df_cas[\"step3_CE\"]\n", + " df_cas['avg_CE'] = (\n", + " df_cas['step1_CE'] + df_cas['step2_CE'] + df_cas['step3_CE']\n", " ) / 3\n", "\n", - " df_cas[\"merged_peaks\"] = df_cas.apply(\n", + " df_cas['merged_peaks'] = df_cas.apply(\n", " lambda x: merge_annotated_spectrum(\n", - " merge_annotated_spectrum(x[\"sim_peaks_20\"], x[\"sim_peaks_35\"]),\n", - " x[\"sim_peaks_50\"],\n", + " merge_annotated_spectrum(x['sim_peaks_20'], x['sim_peaks_35']),\n", + " x['sim_peaks_50'],\n", " ),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_cosine\"] = df_cas.apply(\n", - " 
lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " df_cas['merged_cosine'] = df_cas.apply(\n", + " lambda x: spectral_cosine(x['peaks'], x['merged_peaks']), axis=1\n", " )\n", - " df_cas[\"merged_sqrt_cosine\"] = df_cas.apply(\n", - " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt),\n", + " df_cas['merged_sqrt_cosine'] = df_cas.apply(\n", + " lambda x: spectral_cosine(x['peaks'], x['merged_peaks'], transform=np.sqrt),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_sqrt_cosine_wo_prec\"] = df_cas.apply(\n", + " df_cas['merged_sqrt_cosine_wo_prec'] = df_cas.apply(\n", " lambda x: spectral_cosine(\n", - " x[\"peaks\"],\n", - " x[\"merged_peaks\"],\n", + " x['peaks'],\n", + " x['merged_peaks'],\n", " transform=np.sqrt,\n", - " remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(\n", - " x[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " remove_mz=x['Metabolite'].get_theoretical_precursor_mz(\n", + " x['Metabolite'].metadata['precursor_mode']\n", " ),\n", " ),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_refl_cosine\"] = df_cas.apply(\n", + " df_cas['merged_refl_cosine'] = df_cas.apply(\n", " lambda x: spectral_reflection_cosine(\n", - " x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt\n", + " x['peaks'], x['merged_peaks'], transform=np.sqrt\n", " ),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_steins\"] = df_cas.apply(\n", - " lambda x: reweighted_dot(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " df_cas['merged_steins'] = df_cas.apply(\n", + " lambda x: reweighted_dot(x['peaks'], x['merged_peaks']), axis=1\n", " )\n", - " df_cas[\"spectral_sqrt_cosine\"] = df_cas[\n", - " \"merged_sqrt_cosine\"\n", + " df_cas['spectral_sqrt_cosine'] = df_cas[\n", + " 'merged_sqrt_cosine'\n", " ] # just remember it is merged\n", - " df_cas[\"spectral_sqrt_cosine_wo_prec\"] = df_cas[\n", - " \"merged_sqrt_cosine_wo_prec\"\n", + " df_cas['spectral_sqrt_cosine_wo_prec'] = df_cas[\n", + " 'merged_sqrt_cosine_wo_prec'\n", " ] # just remember it is merged\n", "\n", - " df_cas[\"coverage\"] = df_cas[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", - " df_cas[\"RT_pred\"] = df_cas[\"RT_pred_35\"]\n", - " df_cas[\"RT_dif\"] = df_cas[\"RT_dif_35\"]\n", - " df_cas[\"CCS_pred\"] = df_cas[\"CCS_pred_35\"]\n", - " df_cas[\"library\"] = \"CASMI-16\"\n", + " df_cas['coverage'] = df_cas['Metabolite'].apply(lambda x: x.match_stats['coverage'])\n", + " df_cas['RT_pred'] = df_cas['RT_pred_35']\n", + " df_cas['RT_dif'] = df_cas['RT_dif_35']\n", + " df_cas['CCS_pred'] = df_cas['CCS_pred_35']\n", + " df_cas['library'] = 'CASMI-16'\n", "\n", " if return_df:\n", " return df_cas\n", @@ -1511,12 +1514,12 @@ ], "source": [ "def add_ce(d):\n", - " if \"phospho\" in d[\"name\"]:\n", - " d.update({\"ce_steps\": [20, 30, 40]})\n", + " if 'phospho' in d['name']:\n", + " d.update({'ce_steps': [20, 30, 40]})\n", " return d\n", "\n", "\n", - "df_train[\"summary\"] = df_train[\"summary\"].apply(add_ce)" + "df_train['summary'] = df_train['summary'].apply(add_ce)" ] }, { @@ -1533,7 +1536,7 @@ } ], "source": [ - "print(df_train.iloc[0][\"summary\"])" + "print(df_train.iloc[0]['summary'])" ] }, { @@ -1551,11 +1554,11 @@ ], "source": [ "df_train.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], setup_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),\n", " axis=1,\n", ")\n", - "geo_data = df_train[\"Metabolite\"].apply(lambda x: x.as_geometric_data().to(dev)).values\n", - 
"print(f\"Prepared training/validation with {len(geo_data)} data points\")" + "geo_data = df_train['Metabolite'].apply(lambda x: x.as_geometric_data().to(dev)).values\n", + "print(f'Prepared training/validation with {len(geo_data)} data points')" ] }, { @@ -1577,7 +1580,7 @@ " doubled_batch[key] = torch.cat([value, value], dim=0)\n", "\n", " # Update batch index to reflect duplication in order\n", - " doubled_batch[\"batch\"] = torch.repeat_interleave(batch[\"batch\"], 2)\n", + " doubled_batch['batch'] = torch.repeat_interleave(batch['batch'], 2)\n", "\n", " yield doubled_batch" ] @@ -1603,7 +1606,7 @@ "loader_base = geom_loader.DataLoader\n", "dataloader = loader_base(geo_data, batch_size=5, num_workers=0, shuffle=False)\n", "for id, batch in enumerate(dataloader):\n", - " print(batch[\"ce_steps\"])\n", + " print(batch['ce_steps'])\n", " break" ] }, @@ -1625,7 +1628,7 @@ } ], "source": [ - "raise KeyboardInterrupt(\"Stop before starting a new training run!\")" + "raise KeyboardInterrupt('Stop before starting a new training run!')" ] }, { @@ -1647,9 +1650,9 @@ ], "source": [ "for id, batch in enumerate(dataloader):\n", - " print(batch[\"edge_index\"][0, :][:10])\n", - " print(torch.repeat_interleave(batch[\"edge_index\"][0, :], 5)[:50])\n", - " print(batch[\"batch\"][torch.repeat_interleave(batch[\"edge_index\"][0, :], 5)])\n", + " print(batch['edge_index'][0, :][:10])\n", + " print(torch.repeat_interleave(batch['edge_index'][0, :], 5)[:50])\n", + " print(batch['batch'][torch.repeat_interleave(batch['edge_index'][0, :], 5)])\n", " break\n", " # Feed forward\n", " model.train()\n", @@ -1657,39 +1660,39 @@ " y_pred = model(batch, with_RT=with_RT, with_CCS=with_CCS)\n", " kwargs = {}\n", " if with_weights:\n", - " kwargs = {\"weight\": batch[\"weight_tensor\"]}\n", + " kwargs = {'weight': batch['weight_tensor']}\n", "\n", - " loss = loss_fn(y_pred[\"fragment_probs\"], batch[self.y_tag], **kwargs) # with logits\n", + " loss = loss_fn(y_pred['fragment_probs'], batch[self.y_tag], **kwargs) # with logits\n", " if not rt_metric:\n", - " metrics(y_pred[\"fragment_probs\"], batch[self.y_tag], **kwargs) # call update\n", + " metrics(y_pred['fragment_probs'], batch[self.y_tag], **kwargs) # call update\n", "\n", " # Add RT and CCS to loss\n", " if with_RT:\n", " if with_weights:\n", - " kwargs[\"weight\"] = batch[\"weight\"][batch[\"retention_mask\"]]\n", + " kwargs['weight'] = batch['weight'][batch['retention_mask']]\n", " loss_rt = loss_fn(\n", - " y_pred[\"rt\"][batch[\"retention_mask\"]],\n", - " batch[\"retention_time\"][batch[\"retention_mask\"]],\n", + " y_pred['rt'][batch['retention_mask']],\n", + " batch['retention_time'][batch['retention_mask']],\n", " **kwargs,\n", " )\n", " loss = loss + loss_rt\n", "\n", " if with_CCS:\n", " if with_weights:\n", - " kwargs[\"weight\"] = batch[\"weight\"][batch[\"ccs_mask\"]]\n", + " kwargs['weight'] = batch['weight'][batch['ccs_mask']]\n", " loss_ccs = loss_fn(\n", - " y_pred[\"ccs\"][batch[\"ccs_mask\"]], batch[\"ccs\"][batch[\"ccs_mask\"]], **kwargs\n", + " y_pred['ccs'][batch['ccs_mask']], batch['ccs'][batch['ccs_mask']], **kwargs\n", " )\n", " loss = loss + loss_ccs\n", "\n", " if rt_metric:\n", " metrics(\n", - " y_pred[\"rt\"][batch[\"retention_mask\"]],\n", - " batch[\"retention_time\"][batch[\"retention_mask\"]],\n", + " y_pred['rt'][batch['retention_mask']],\n", + " batch['retention_time'][batch['retention_mask']],\n", " **kwargs,\n", " ) # call update\n", " metrics(\n", - " y_pred[\"ccs\"][batch[\"ccs_mask\"]], 
batch[\"ccs\"][batch[\"ccs_mask\"]], **kwargs\n", + " y_pred['ccs'][batch['ccs_mask']], batch['ccs'][batch['ccs_mask']], **kwargs\n", " ) # call update\n", "\n", " # Backpropagate\n", @@ -1963,7 +1966,7 @@ } ], "source": [ - "print(f\"Training model\")\n", + "print('Training model')\n", "model, checkpoints = train_new_model() # continue_with_model=model)" ] }, @@ -1982,10 +1985,8 @@ } ], "source": [ - "import copy\n", - "\n", "print(checkpoints)\n", - "print(np.sqrt(checkpoints[\"val_loss\"]))\n", + "print(np.sqrt(checkpoints['val_loss']))\n", "model_end = copy.deepcopy(model)" ] }, @@ -2085,34 +2086,34 @@ ], "source": [ "model = model_end # GNNCompiler.load(checkpoints[\"file\"]).to(dev)\n", - "score = \"spectral_sqrt_cosine\"\n", + "score = 'spectral_sqrt_cosine'\n", "\n", "val_results = test_model(\n", - " model, df_train[df_train[\"dataset\"] == \"validation\"], score=score\n", + " model, df_train[df_train['dataset'] == 'validation'], score=score\n", ")\n", "test_results = test_model(model, df_test, score=score)\n", "casmi16_results = test_cas16(model, score=score)\n", - "casmi16_p = test_cas16(model, df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"], score=score)\n", - "casmi16_n = test_cas16(model, df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"], score=score)\n", + "casmi16_p = test_cas16(model, df_cas[df_cas['Precursor_type'] == '[M+H]+'], score=score)\n", + "casmi16_n = test_cas16(model, df_cas[df_cas['Precursor_type'] == '[M-H]-'], score=score)\n", "casmi22_results = test_model(model, df_cas22, score=score)\n", "casmi22_p = test_model(\n", - " model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"], score=score\n", + " model, df_cas22[df_cas22['Precursor_type'] == '[M+H]+'], score=score\n", ")\n", "casmi22_n = test_model(\n", - " model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"], score=score\n", + " model, df_cas22[df_cas22['Precursor_type'] == '[M-H]-'], score=score\n", ")\n", "\n", "results = [\n", " {\n", - " \"model\": model,\n", - " \"validation\": val_results,\n", - " \"test\": test_results,\n", - " \"casmi16\": casmi16_results,\n", - " \"casmi22\": casmi22_results,\n", - " \"casmi16+\": casmi16_p,\n", - " \"casmi16-\": casmi16_n,\n", - " \"casmi22+\": casmi22_p,\n", - " \"casmi22-\": casmi22_n,\n", + " 'model': model,\n", + " 'validation': val_results,\n", + " 'test': test_results,\n", + " 'casmi16': casmi16_results,\n", + " 'casmi22': casmi22_results,\n", + " 'casmi16+': casmi16_p,\n", + " 'casmi16-': casmi16_n,\n", + " 'casmi22+': casmi22_p,\n", + " 'casmi22-': casmi22_n,\n", " }\n", "]" ] @@ -2212,33 +2213,33 @@ } ], "source": [ - "score = \"spectral_sqrt_cosine_wo_prec\"\n", + "score = 'spectral_sqrt_cosine_wo_prec'\n", "val_results = test_model(\n", - " model, df_train[df_train[\"dataset\"] == \"validation\"], score=score\n", + " model, df_train[df_train['dataset'] == 'validation'], score=score\n", ")\n", "test_results = test_model(model, df_test, score=score)\n", "casmi16_results = test_cas16(model, score=score)\n", - "casmi16_p = test_cas16(model, df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"], score=score)\n", - "casmi16_n = test_cas16(model, df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"], score=score)\n", + "casmi16_p = test_cas16(model, df_cas[df_cas['Precursor_type'] == '[M+H]+'], score=score)\n", + "casmi16_n = test_cas16(model, df_cas[df_cas['Precursor_type'] == '[M-H]-'], score=score)\n", "casmi22_results = test_model(model, df_cas22, score=score)\n", "casmi22_p = test_model(\n", - " model, df_cas22[df_cas22[\"Precursor_type\"] == 
\"[M+H]+\"], score=score\n", + " model, df_cas22[df_cas22['Precursor_type'] == '[M+H]+'], score=score\n", ")\n", "casmi22_n = test_model(\n", - " model, df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"], score=score\n", + " model, df_cas22[df_cas22['Precursor_type'] == '[M-H]-'], score=score\n", ")\n", "\n", "results_wop = [\n", " {\n", - " \"model\": model,\n", - " \"validation\": val_results,\n", - " \"test\": test_results,\n", - " \"casmi16\": casmi16_results,\n", - " \"casmi22\": casmi22_results,\n", - " \"casmi16+\": casmi16_p,\n", - " \"casmi16-\": casmi16_n,\n", - " \"casmi22+\": casmi22_p,\n", - " \"casmi22-\": casmi22_n,\n", + " 'model': model,\n", + " 'validation': val_results,\n", + " 'test': test_results,\n", + " 'casmi16': casmi16_results,\n", + " 'casmi22': casmi22_results,\n", + " 'casmi16+': casmi16_p,\n", + " 'casmi16-': casmi16_n,\n", + " 'casmi22+': casmi22_p,\n", + " 'casmi22-': casmi22_n,\n", " }\n", "]" ] @@ -2488,7 +2489,7 @@ } ], "source": [ - "raise KeyboardInterrupt(\"Halt! Make sure you wish to save/overwrite model files\")" + "raise KeyboardInterrupt('Halt! Make sure you wish to save/overwrite model files')" ] }, { @@ -2505,10 +2506,10 @@ } ], "source": [ - "depth = model_params[\"depth\"]\n", - "MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v0.0.1_OS_depth{depth}_Aug24_sqrt_XXX.pt\"\n", + "depth = model_params['depth']\n", + "MODEL_PATH = f'{home}/data/metabolites/pretrained_models/v0.0.1_OS_depth{depth}_Aug24_sqrt_XXX.pt'\n", "model.save(MODEL_PATH)\n", - "print(f\"Saved to {MODEL_PATH}\")" + "print(f'Saved to {MODEL_PATH}')" ] }, { @@ -2544,9 +2545,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "from fiora.GNN.GNNModules import GNNCompiler" - ] + "source": [] }, { "cell_type": "code", @@ -2565,7 +2564,7 @@ } ], "source": [ - "df_test[\"Metabolite\"].iloc[0].as_geometric_data()" + "df_test['Metabolite'].iloc[0].as_geometric_data()" ] }, { @@ -2574,7 +2573,7 @@ "metadata": {}, "outputs": [], "source": [ - "dev = \"cuda:0\"\n", + "dev = 'cuda:0'\n", "mymy = GNNCompiler.load(\n", " MODEL_PATH\n", ") # f\"{home}/data/metabolites/pretrained_models/v0.0.1_merged_2.pt\"\n", @@ -2631,10 +2630,10 @@ "source": [ "import json\n", "\n", - "with open(MODEL_PATH.replace(\".pt\", \"_params.json\"), \"r\") as fp:\n", + "with open(MODEL_PATH.replace('.pt', '_params.json'), 'r') as fp:\n", " p = json.load(fp)\n", "hh = GNNCompiler(p)\n", - "hh.load_state_dict(torch.load(MODEL_PATH.replace(\".pt\", \"_state.pt\")))\n", + "hh.load_state_dict(torch.load(MODEL_PATH.replace('.pt', '_state.pt')))\n", "hh.eval()\n", "hh = hh.to(dev)" ] @@ -2677,7 +2676,7 @@ } ], "source": [ - "raise KeyboardInterrupt(\"TODO\")" + "raise KeyboardInterrupt('TODO')" ] }, { @@ -2710,25 +2709,25 @@ "import os\n", "\n", "save_df = False\n", - "cfm_directory = f\"{home}/data/metabolites/cfm-id/\"\n", - "name = \"test_split_negative_solutions_cfm.txt\"\n", - "df_cfm = df_test[[\"group_id\", \"SMILES\", \"Precursor_type\"]]\n", - "df_n = df_cfm[df_cfm[\"Precursor_type\"] == \"[M-H]-\"].drop_duplicates(\n", - " subset=\"group_id\", keep=\"first\"\n", + "cfm_directory = f'{home}/data/metabolites/cfm-id/'\n", + "name = 'test_split_negative_solutions_cfm.txt'\n", + "df_cfm = df_test[['group_id', 'SMILES', 'Precursor_type']]\n", + "df_n = df_cfm[df_cfm['Precursor_type'] == '[M-H]-'].drop_duplicates(\n", + " subset='group_id', keep='first'\n", ")\n", - "df_p = df_cfm[df_cfm[\"Precursor_type\"] == \"[M+H]+\"].drop_duplicates(\n", - " subset=\"group_id\", keep=\"first\"\n", 
+ "df_p = df_cfm[df_cfm['Precursor_type'] == '[M+H]+'].drop_duplicates(\n", + " subset='group_id', keep='first'\n", ")\n", "\n", "print(df_n.head())\n", "\n", "if save_df:\n", " file = os.path.join(cfm_directory, name)\n", - " df_n[[\"group_id\", \"SMILES\"]].to_csv(file, index=False, header=False, sep=\" \")\n", + " df_n[['group_id', 'SMILES']].to_csv(file, index=False, header=False, sep=' ')\n", "\n", - " name = name.replace(\"negative\", \"positive\")\n", + " name = name.replace('negative', 'positive')\n", " file = os.path.join(cfm_directory, name)\n", - " df_p[[\"group_id\", \"SMILES\"]].to_csv(file, index=False, header=False, sep=\" \")" + " df_p[['group_id', 'SMILES']].to_csv(file, index=False, header=False, sep=' ')" ] }, { @@ -2745,7 +2744,9 @@ "outputs": [], "source": [ "import json\n", + "\n", "from fiora.GNN.GNNModules import GNNCompiler\n", + "\n", "from fiora.MS.SimulationFramework import SimulationFramework" ] }, @@ -2795,15 +2796,15 @@ "source": [ "# Load best model\n", "\n", - "dev = \"cuda:0\"\n", + "dev = 'cuda:0'\n", "# MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/pre_package/v0.0.1_merged_depth6_Jan24.pt\"\n", - "MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v0.0.1_merged_depth6_Aug24_sqrt.pt\" # New sqrt model (improved)\n", + "MODEL_PATH = f'{home}/data/metabolites/pretrained_models/v0.0.1_merged_depth6_Aug24_sqrt.pt' # New sqrt model (improved)\n", "\n", "try:\n", " model = GNNCompiler.load_from_state_dict(MODEL_PATH)\n", - " print(\"Model loaded from state dict without errors.\")\n", + " print('Model loaded from state dict without errors.')\n", "except:\n", - " raise NameError(\"Error: Failed loading from state dict.\")\n", + " raise NameError('Error: Failed loading from state dict.')\n", "\n", "\n", "model.eval()\n", @@ -2819,11 +2820,11 @@ "outputs": [], "source": [ "spectral_modules = [\n", - " \"node_embedding\",\n", - " \"edge_embedding\",\n", - " \"GNN_module\",\n", - " \"edge_module\",\n", - " \"precursor_module\",\n", + " 'node_embedding',\n", + " 'edge_embedding',\n", + " 'GNN_module',\n", + " 'edge_module',\n", + " 'precursor_module',\n", "]\n", "\n", "\n", @@ -2858,7 +2859,7 @@ "source": [ "for name, param in model.named_parameters():\n", " if param.requires_grad:\n", - " print(f\"{name}: requires gradients\")\n", + " print(f'{name}: requires gradients')\n", " # else:\n", " # print(f\"{name}: does not require gradients (frozen)\")" ] @@ -2882,7 +2883,7 @@ } ], "source": [ - "df_train[\"RTorCCS\"] = ~(df_train[\"RETENTIONTIME\"].isna() & df_train[\"CCS\"].isna())" + "df_train['RTorCCS'] = ~(df_train['RETENTIONTIME'].isna() & df_train['CCS'].isna())" ] }, { @@ -2901,26 +2902,26 @@ } ], "source": [ - "rt_index = df_train.drop_duplicates(\"group_id\", keep=\"first\")[\"RTorCCS\"]\n", + "rt_index = df_train.drop_duplicates('group_id', keep='first')['RTorCCS']\n", "print(\n", - " \"RT: \",\n", + " 'RT: ',\n", " sum(\n", - " ~df_train.drop_duplicates(\"group_id\", keep=\"first\")[rt_index][\n", - " \"RETENTIONTIME\"\n", + " ~df_train.drop_duplicates('group_id', keep='first')[rt_index][\n", + " 'RETENTIONTIME'\n", " ].isna()\n", " ),\n", ")\n", "print(\n", - " \"CCS: \",\n", - " sum(~df_train.drop_duplicates(\"group_id\", keep=\"first\")[rt_index][\"CCS\"].isna()),\n", + " 'CCS: ',\n", + " sum(~df_train.drop_duplicates('group_id', keep='first')[rt_index]['CCS'].isna()),\n", ")\n", "\n", "geo_data = (\n", - " df_train.drop_duplicates(\"group_id\", keep=\"first\")[rt_index][\"Metabolite\"]\n", + " df_train.drop_duplicates('group_id', 
keep='first')[rt_index]['Metabolite']\n", " .apply(lambda x: x.as_geometric_data().to(dev))\n", " .values\n", ")\n", - "print(f\"Prepared training/validation with {len(geo_data)} data points\")" + "print(f'Prepared training/validation with {len(geo_data)} data points')" ] }, { @@ -2935,16 +2936,16 @@ "\n", "\n", "def train_rt_model(rt_lr=rt_lr, rt_batch=rt_batch, rt_epochs=rt_epochs):\n", - " y_label = \"compiled_probsALL\"\n", + " y_label = 'compiled_probsALL'\n", " optimizer = torch.optim.Adam(model.parameters(), lr=rt_lr)\n", " train_keys, val_keys = (\n", - " df[df[\"dataset\"] == \"training\"][\"group_id\"].unique(),\n", - " df[df[\"dataset\"] == \"validation\"][\"group_id\"].unique(),\n", + " df[df['dataset'] == 'training']['group_id'].unique(),\n", + " df[df['dataset'] == 'validation']['group_id'].unique(),\n", " )\n", " trainer = Trainer(\n", " geo_data,\n", " y_tag=y_label,\n", - " problem_type=\"regression\",\n", + " problem_type='regression',\n", " train_keys=train_keys,\n", " val_keys=val_keys,\n", " metric_dict=None,\n", @@ -2953,7 +2954,7 @@ " device=dev,\n", " )\n", " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", - " optimizer, patience=8, factor=0.8, mode=\"min\", verbose=True\n", + " optimizer, patience=8, factor=0.8, mode='min', verbose=True\n", " )\n", " checkpoints = trainer.train(\n", " model,\n", @@ -3590,10 +3591,10 @@ } ], "source": [ - "print(cp, np.sqrt(cp[\"val_loss\"]))\n", - "model = GNNCompiler.load(cp[\"file\"])\n", + "print(cp, np.sqrt(cp['val_loss']))\n", + "model = GNNCompiler.load(cp['file'])\n", "model.to(dev)\n", - "print(\"Loaded best checkpoint\")" + "print('Loaded best checkpoint')" ] }, { @@ -3610,7 +3611,7 @@ } ], "source": [ - "NEW_MODEL_PATH = MODEL_PATH.replace(\".pt\", \"+CCS+RT_drop4.pt\")\n", + "NEW_MODEL_PATH = MODEL_PATH.replace('.pt', '+CCS+RT_drop4.pt')\n", "model = GNNCompiler.load(NEW_MODEL_PATH)\n", "model.to(dev)\n", "print(\"Loaded best 'new' model\")" @@ -3676,12 +3677,12 @@ "metadata": {}, "outputs": [], "source": [ - "raise KeyboardInterrupt(\"Halt! Make sure you wanna save the model\")\n", + "raise KeyboardInterrupt('Halt! 
Make sure you want to save the model')\n",
     "# Adjust flags if necessary\n",
-    "model.model_params[\"rt_supported\"] = True\n",
-    "model.model_params[\"ccs_supported\"] = True\n",
+    "model.model_params['rt_supported'] = True\n",
+    "model.model_params['ccs_supported'] = True\n",
     "\n",
-    "NEW_MODEL_PATH = MODEL_PATH.replace(\".pt\", \"+CCS+RT_dropX.pt\")\n",
+    "NEW_MODEL_PATH = MODEL_PATH.replace('.pt', '+CCS+RT_dropX.pt')\n",
     "model.save(NEW_MODEL_PATH)"
   ]
  },
@@ -3717,13 +3718,13 @@
    ],
    "source": [
     "val_df = test_model(\n",
-    "    model, df_train[df_train[\"dataset\"] == \"validation\"], return_df=True\n",
+    "    model, df_train[df_train['dataset'] == 'validation'], return_df=True\n",
     ")\n",
     "test_df = test_model(model, df_test, return_df=True)\n",
     "casmi16_df = test_cas16(model, return_df=True)\n",
-    "val_df[\"Dataset\"] = \"Val\"\n",
-    "casmi16_df[\"Dataset\"] = \"CASMI 22\"\n",
-    "test_df[\"Dataset\"] = \"Test split\""
+    "val_df['Dataset'] = 'Val'\n",
+    "casmi16_df['Dataset'] = 'CASMI 16'\n",
+    "test_df['Dataset'] = 'Test split'"
   ]
  },
  {
@@ -3744,48 +3745,49 @@
    ],
    "source": [
     "import seaborn as sns\n",
+    "\n",
     "from fiora.visualization.define_colors import *\n",
     "\n",
     "fig, axs = plt.subplots(\n",
-    "    2, 1, figsize=(12, 14), gridspec_kw={\"height_ratios\": [1, 5]}, sharex=True\n",
+    "    2, 1, figsize=(12, 14), gridspec_kw={'height_ratios': [1, 5]}, sharex=True\n",
     ")\n",
     "plt.subplots_adjust(wspace=0.1, hspace=0.05)\n",
     "set_light_theme()\n",
     "\n",
-    "df_test_unique = test_df.drop_duplicates(subset=\"group_id\", keep=\"first\")\n",
+    "df_test_unique = test_df.drop_duplicates(subset='group_id', keep='first')\n",
     "\n",
     "\n",
     "# sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n",
     "sns.kdeplot(\n",
     "    ax=axs[0],\n",
     "    data=df_test_unique,\n",
-    "    x=\"RETENTIONTIME\",\n",
+    "    x='RETENTIONTIME',\n",
     "    bw_adjust=0.25,\n",
-    "    color=\"gray\",\n",
+    "    color='gray',\n",
     "    fill=True,\n",
     ")  # , multiple=\"stack\") #hue=\"Precursor_type\",\n",
     "sns.kdeplot(\n",
-    "    ax=axs[0], data=df_test_unique, x=\"RETENTIONTIME\", bw_adjust=0.25, color=\"gray\"\n",
+    "    ax=axs[0], data=df_test_unique, x='RETENTIONTIME', bw_adjust=0.25, color='gray'\n",
     ")  # , multiple=\"stack\") #hue=\"Precursor_type\",\n",
     "\n",
-    "axs[0].spines[\"top\"].set_visible(False)\n",
-    "axs[0].spines[\"right\"].set_visible(False)\n",
+    "axs[0].spines['top'].set_visible(False)\n",
+    "axs[0].spines['right'].set_visible(False)\n",
     "\n",
     "\n",
     "sns.scatterplot(\n",
-    "    ax=axs[1], data=df_test_unique, x=\"RETENTIONTIME\", y=\"RT_pred\", color=\"gray\"\n",
+    "    ax=axs[1], data=df_test_unique, x='RETENTIONTIME', y='RT_pred', color='gray'\n",
     ")  # , hue=\"library\", palette=tri_palette, style=\"library\", color=\"gray\")\n",
-    "axs[1].set_ylim([0, df_test_unique[\"RETENTIONTIME\"].max() + 1])\n",
-    "axs[1].set_xlim([0, df_test_unique[\"RETENTIONTIME\"].max() + 1])\n",
-    "axs[1].set_ylabel(\"Predicted retention time\")\n",
-    "axs[1].set_xlabel(\"Observed retention time\")\n",
+    "axs[1].set_ylim([0, df_test_unique['RETENTIONTIME'].max() + 1])\n",
+    "axs[1].set_xlim([0, df_test_unique['RETENTIONTIME'].max() + 1])\n",
+    "axs[1].set_ylabel('Predicted retention time')\n",
+    "axs[1].set_xlabel('Observed retention time')\n",
     "line = [0, 100]\n",
-    "sns.lineplot(ax=axs[1], x=line, y=line, color=\"black\")\n",
+    "sns.lineplot(ax=axs[1], x=line, y=line, color='black')\n",
     "sns.lineplot(\n",
-    "    ax=axs[1], 
x=line, y=[x + 20 / 60.0 for x in line], color=\"black\", linestyle=\"--\"\n", + " ax=axs[1], x=line, y=[x + 20 / 60.0 for x in line], color='black', linestyle='--'\n", ")\n", "sns.lineplot(\n", - " ax=axs[1], x=line, y=[x - 20 / 60.0 for x in line], color=\"black\", linestyle=\"--\"\n", + " ax=axs[1], x=line, y=[x - 20 / 60.0 for x in line], color='black', linestyle='--'\n", ")\n", "plt.show()" ] @@ -3821,60 +3823,60 @@ "source": [ "# TODO NEXT UP!!\n", "fig, axs = plt.subplots(\n", - " 2, 1, figsize=(12, 14), gridspec_kw={\"height_ratios\": [1, 5]}, sharex=True\n", + " 2, 1, figsize=(12, 14), gridspec_kw={'height_ratios': [1, 5]}, sharex=True\n", ")\n", "plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", "\n", - "df_test_unique = test_df.drop_duplicates(subset=\"group_id\", keep=\"first\")\n", + "df_test_unique = test_df.drop_duplicates(subset='group_id', keep='first')\n", "CCS = pd.concat(\n", " [\n", - " df_test_unique[[\"CCS\", \"CCS_pred\", \"Dataset\"]],\n", - " casmi16_df[[\"CCS\", \"CCS_pred\", \"Dataset\"]],\n", + " df_test_unique[['CCS', 'CCS_pred', 'Dataset']],\n", + " casmi16_df[['CCS', 'CCS_pred', 'Dataset']],\n", " ],\n", " ignore_index=True,\n", - ") #\n", + ")\n", "\n", "\n", "# sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", "sns.kdeplot(\n", " ax=axs[0],\n", " data=CCS,\n", - " x=\"CCS\",\n", + " x='CCS',\n", " bw_adjust=0.35,\n", - " color=\"black\",\n", - " multiple=\"stack\",\n", - " hue=\"Dataset\",\n", + " color='black',\n", + " multiple='stack',\n", + " hue='Dataset',\n", " palette=tri_palette,\n", - " edgecolor=\"white\",\n", + " edgecolor='white',\n", ") # hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", - "axs[0].spines[\"top\"].set_visible(False)\n", - "axs[0].spines[\"right\"].set_visible(False)\n", + "axs[0].spines['top'].set_visible(False)\n", + "axs[0].spines['right'].set_visible(False)\n", "\n", "\n", "sns.scatterplot(\n", " ax=axs[1],\n", " data=CCS,\n", - " x=\"CCS\",\n", - " y=\"CCS_pred\",\n", - " hue=\"Dataset\",\n", + " x='CCS',\n", + " y='CCS_pred',\n", + " hue='Dataset',\n", " palette=tri_palette,\n", - " style=\"Dataset\",\n", - " markers=[\".\", \"X\", \"*\"],\n", - " color=\"gray\",\n", + " style='Dataset',\n", + " markers=['.', 'X', '*'],\n", + " color='gray',\n", " s=25,\n", " linewidth=0.0,\n", ") # , color=\"blue\", edgecolor=\"blue\")#,\n", - "axs[1].set_ylim([df_test_unique[\"CCS\"].min() - 10, df_test_unique[\"CCS\"].max() + 10])\n", - "axs[1].set_xlim([df_test_unique[\"CCS\"].min() - 10, df_test_unique[\"CCS\"].max() + 10])\n", - "axs[1].set_ylabel(\"Predicted CCS\")\n", - "axs[1].set_xlabel(\"Observed CCS\")\n", - "line = [df_test_unique[\"CCS\"].min() - 10, df_test_unique[\"CCS\"].max() + 10]\n", - "sns.lineplot(ax=axs[1], x=line, y=line, color=\"black\")\n", + "axs[1].set_ylim([df_test_unique['CCS'].min() - 10, df_test_unique['CCS'].max() + 10])\n", + "axs[1].set_xlim([df_test_unique['CCS'].min() - 10, df_test_unique['CCS'].max() + 10])\n", + "axs[1].set_ylabel('Predicted CCS')\n", + "axs[1].set_xlabel('Observed CCS')\n", + "line = [df_test_unique['CCS'].min() - 10, df_test_unique['CCS'].max() + 10]\n", + "sns.lineplot(ax=axs[1], x=line, y=line, color='black')\n", "sns.lineplot(\n", - " ax=axs[1], x=line, y=[1.1 * x for x in line], color=\"black\", linestyle=\"--\"\n", + " ax=axs[1], x=line, y=[1.1 * x for x in line], color='black', 
linestyle='--'\n", ")\n", "sns.lineplot(\n", - " ax=axs[1], x=line, y=[0.9 * x for x in line], color=\"black\", linestyle=\"--\"\n", + " ax=axs[1], x=line, y=[0.9 * x for x in line], color='black', linestyle='--'\n", ")\n", "plt.show()" ] @@ -3910,77 +3912,77 @@ } ], "source": [ - "print(\"TEST SPLIT:\\n\")\n", - "print(\"Pearson Corr Coef:\")\n", + "print('TEST SPLIT:\\n')\n", + "print('Pearson Corr Coef:')\n", "print(\n", - " \"GNN\",\n", + " 'GNN',\n", " np.corrcoef(\n", - " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " df_test_unique.dropna(subset=['CCS'])['CCS'],\n", + " df_test_unique.dropna(subset=['CCS'])['CCS_pred'].dropna(),\n", " dtype=float,\n", " )[0, 1],\n", ")\n", "print(\n", - " \"LR \",\n", + " 'LR ',\n", " np.corrcoef(\n", - " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " df_test_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna(),\n", + " df_test_unique.dropna(subset=['CCS'])['CCS'],\n", + " df_test_unique.dropna(subset=['CCS'])['PrecursorMZ'].dropna(),\n", " dtype=float,\n", " )[0, 1],\n", ")\n", "\n", "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(\n", - " df_train.dropna(subset=[\"CCS\"])[\"PRECURSORMZ\"],\n", - " df_train.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_train.dropna(subset=['CCS'])['PRECURSORMZ'],\n", + " df_train.dropna(subset=['CCS'])['CCS'],\n", ")\n", - "print(\"R2\")\n", + "print('R2')\n", "print(\n", - " \"GNN\",\n", + " 'GNN',\n", " r2_score(\n", - " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " df_test_unique.dropna(subset=['CCS'])['CCS'],\n", + " df_test_unique.dropna(subset=['CCS'])['CCS_pred'].dropna(),\n", " ),\n", ")\n", "print(\n", - " \"LR \",\n", + " 'LR ',\n", " r2_score(\n", - " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_test_unique.dropna(subset=['CCS'])['CCS'],\n", " intercept\n", - " + slope * df_test_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna(),\n", + " + slope * df_test_unique.dropna(subset=['CCS'])['PrecursorMZ'].dropna(),\n", " ),\n", ")\n", "\n", - "print(\"---------------\\n\\nCASMI-16:\\n\")\n", - "print(\"Pearson Corr Coef:\")\n", + "print('---------------\\n\\nCASMI-16:\\n')\n", + "print('Pearson Corr Coef:')\n", "print(\n", - " \"GNN\",\n", + " 'GNN',\n", " np.corrcoef(\n", - " casmi16_df.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " casmi16_df.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " casmi16_df.dropna(subset=['CCS'])['CCS'],\n", + " casmi16_df.dropna(subset=['CCS'])['CCS_pred'].dropna(),\n", " dtype=float,\n", " )[0, 1],\n", ")\n", "print(\n", - " \"LR \",\n", + " 'LR ',\n", " np.corrcoef(\n", - " casmi16_df.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " casmi16_df.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna(),\n", + " casmi16_df.dropna(subset=['CCS'])['CCS'],\n", + " casmi16_df.dropna(subset=['CCS'])['PRECURSOR_MZ'].dropna(),\n", " dtype=float,\n", " )[0, 1],\n", ")\n", - "print(\"R2\")\n", + "print('R2')\n", "print(\n", - " \"GNN\",\n", + " 'GNN',\n", " r2_score(\n", - " casmi16_df.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " casmi16_df.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " casmi16_df.dropna(subset=['CCS'])['CCS'],\n", + " casmi16_df.dropna(subset=['CCS'])['CCS_pred'].dropna(),\n", " ),\n", ")\n", "print(\n", - " \"LR \",\n", + " 'LR ',\n", " r2_score(\n", - " casmi16_df.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " intercept + slope * 
casmi16_df.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna(),\n", + " casmi16_df.dropna(subset=['CCS'])['CCS'],\n", + " intercept + slope * casmi16_df.dropna(subset=['CCS'])['PRECURSOR_MZ'].dropna(),\n", " ),\n", ")" ] diff --git a/notebooks/test_model.ipynb b/notebooks/test_model.ipynb index cc5d3b4..e922bb1 100644 --- a/notebooks/test_model.ipynb +++ b/notebooks/test_model.ipynb @@ -28,8 +28,8 @@ ], "source": [ "import sys\n", - "import torch\n", "\n", + "import torch\n", "\n", "seed = 42\n", "# torch.set_default_dtype(torch.float64)\n", @@ -37,30 +37,31 @@ "torch.set_printoptions(precision=2, sci_mode=False)\n", "\n", "\n", - "import pandas as pd\n", - "import numpy as np\n", "import ast\n", "import copy\n", + "\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", "import seaborn as sns\n", "\n", "# Load Modules\n", - "sys.path.append(\"..\")\n", + "sys.path.append('..')\n", "from os.path import expanduser\n", "\n", - "home = expanduser(\"~\")\n", - "from fiora.MOL.constants import *\n", - "from fiora.IO.LibraryLoader import LibraryLoader\n", - "from fiora.MOL.FragmentationTree import FragmentationTree\n", - "import fiora.visualization.spectrum_visualizer as sv\n", - "\n", - "from sklearn.metrics import r2_score\n", + "home = expanduser('~')\n", "import scipy\n", "from rdkit import RDLogger\n", + "from sklearn.metrics import r2_score\n", "\n", - "RDLogger.DisableLog(\"rdApp.*\")\n", + "import fiora.visualization.spectrum_visualizer as sv\n", + "from fiora.IO.LibraryLoader import LibraryLoader\n", + "from fiora.MOL.constants import *\n", + "from fiora.MOL.FragmentationTree import FragmentationTree\n", + "\n", + "RDLogger.DisableLog('rdApp.*')\n", "\n", - "print(f\"Working with Python {sys.version}\")" + "print(f'Working with Python {sys.version}')" ] }, { @@ -82,7 +83,7 @@ "# MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/pre_package/v0.0.1_2_OS_depth{depth}_June24+CCS+RT.pt\" # OS model (first try)\n", "\n", "# NEW AND SHINY\n", - "MODEL_PATH = f\"../resources/models/fiora_OS_v1.0.0.pt\" # Release version\n", + "MODEL_PATH = '../resources/models/fiora_OS_v1.0.0.pt' # Release version\n", "# MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v1.0.0_OS_depth10_Sep25_x4.pt\"\n", "# MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v0.0.1_merged_depth{depth}_Aug24_sqrt+CCS+RT_drop3.pt\" # New sqrt model (improved) | Note: drop3 uses dropout reduction while training RT, CCS\n", "# MODEL_PATH = f\"{home}/data/metabolites/pretrained_models/v0.0.1_OS_depth{depth}_Aug24_sqrt_4.pt\" # or Aug24_sqrt_x are new OS models\n", @@ -93,11 +94,10 @@ "from fiora.GNN.FioraModel import FioraModel\n", "from fiora.MS.SimulationFramework import SimulationFramework\n", "\n", - "\n", "try:\n", " model = FioraModel.load_from_state_dict(MODEL_PATH)\n", "except:\n", - " raise NameError(\"Error: Failed loading from state dict.\")" + " raise NameError('Error: Failed loading from state dict.')" ] }, { @@ -115,9 +115,9 @@ ], "source": [ "has_m_plus = False\n", - "if \"setup_features_categorical_set\" in model.model_params.keys():\n", - " print(model.model_params[\"setup_features_categorical_set\"][\"precursor_mode\"])\n", - " if \"[M]+\" in model.model_params[\"setup_features_categorical_set\"][\"precursor_mode\"]:\n", + "if 'setup_features_categorical_set' in model.model_params.keys():\n", + " print(model.model_params['setup_features_categorical_set']['precursor_mode'])\n", + " if '[M]+' in 
model.model_params['setup_features_categorical_set']['precursor_mode']:\n", " has_m_plus = True" ] }, @@ -155,14 +155,14 @@ "\n", "def load_training_data():\n", " L = LibraryLoader()\n", - " df = L.load_from_csv(f\"{home}/data/metabolites/preprocessed/datasplits_Jan24.csv\")\n", + " df = L.load_from_csv(f'{home}/data/metabolites/preprocessed/datasplits_Jan24.csv')\n", " return df\n", "\n", "\n", "def load_msnlib():\n", " L = LibraryLoader()\n", " df = L.load_from_csv(\n", - " f\"{home}/data/metabolites/preprocessed/datasplits_msnlib_Aug24_v3.csv\"\n", + " f'{home}/data/metabolites/preprocessed/datasplits_msnlib_Aug24_v3.csv'\n", " )\n", " return df\n", "\n", @@ -171,15 +171,15 @@ "df_msnlib = load_msnlib()\n", "\n", "# Restore dictionary values\n", - "dict_columns = [\"peaks\", \"summary\"]\n", + "dict_columns = ['peaks', 'summary']\n", "for col in dict_columns:\n", - " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace(\"nan\", \"None\")))\n", + " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", " df_msnlib[col] = df_msnlib[col].apply(\n", - " lambda x: ast.literal_eval(x.replace(\"nan\", \"None\"))\n", + " lambda x: ast.literal_eval(x.replace('nan', 'None'))\n", " )\n", " # df[col] = df[col].apply(ast.literal_eval)\n", "\n", - "df[\"group_id\"] = df[\"group_id\"].astype(int)" + "df['group_id'] = df['group_id'].astype(int)" ] }, { @@ -204,9 +204,9 @@ } ], "source": [ - "print(df.groupby(\"lib\")[\"group_id\"].unique().apply(len))\n", - "print(df[\"lib\"].value_counts())\n", - "print(len(df[\"group_id\"].unique()))" + "print(df.groupby('lib')['group_id'].unique().apply(len))\n", + "print(df['lib'].value_counts())\n", + "print(len(df['group_id'].unique()))" ] }, { @@ -228,14 +228,14 @@ } ], "source": [ - "print(df.groupby(\"dataset\")[\"group_id\"].unique().apply(len))\n", + "print(df.groupby('dataset')['group_id'].unique().apply(len))\n", "\n", - "print(\"Reducing data to test set.\")\n", - "df_train = df[df[\"dataset\"] != \"test\"]\n", - "df_test = df[df[\"dataset\"] == \"test\"]\n", + "print('Reducing data to test set.')\n", + "df_train = df[df['dataset'] != 'test']\n", + "df_test = df[df['dataset'] == 'test']\n", "\n", - "df_msnlib_train = df_msnlib[df_msnlib[\"dataset\"] != \"test\"]\n", - "df_msnlib_test = df_msnlib[df_msnlib[\"dataset\"] == \"test\"]" + "df_msnlib_train = df_msnlib[df_msnlib['dataset'] != 'test']\n", + "df_msnlib_test = df_msnlib[df_msnlib['dataset'] == 'test']" ] }, { @@ -244,76 +244,75 @@ "metadata": {}, "outputs": [], "source": [ - "from fiora.MOL.Metabolite import Metabolite\n", "from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", "from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder\n", "from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder\n", - "\n", + "from fiora.MOL.Metabolite import Metabolite\n", "\n", "CE_upper_limit = 100.0\n", "weight_upper_limit = 1000.0\n", "\n", - "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", - "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", + "node_encoder = AtomFeatureEncoder(feature_list=['symbol', 'num_hydrogen', 'ring_type'])\n", + "bond_encoder = BondFeatureEncoder(feature_list=['bond_type', 'ring_type'])\n", "model_setup_feature_sets = None\n", - "if \"setup_features_categorical_set\" in model.model_params.keys():\n", - " model_setup_feature_sets = model.model_params[\"setup_features_categorical_set\"]\n", + "if 'setup_features_categorical_set' in 
model.model_params.keys():\n", + " model_setup_feature_sets = model.model_params['setup_features_categorical_set']\n", " # TODO Refactor this:\n", " for i, data in df_test.iterrows():\n", " if (\n", - " df_test.loc[i][\"summary\"][\"instrument\"]\n", - " not in model_setup_feature_sets[\"instrument\"]\n", + " df_test.loc[i]['summary']['instrument']\n", + " not in model_setup_feature_sets['instrument']\n", " ):\n", - " df_test.loc[i][\"summary\"][\"instrument\"] = \"HCD\"\n", + " df_test.loc[i]['summary']['instrument'] = 'HCD'\n", "covariate_encoder = CovariateFeatureEncoder(\n", " feature_list=[\n", - " \"collision_energy\",\n", - " \"molecular_weight\",\n", - " \"precursor_mode\",\n", - " \"instrument\",\n", - " \"element_composition\",\n", + " 'collision_energy',\n", + " 'molecular_weight',\n", + " 'precursor_mode',\n", + " 'instrument',\n", + " 'element_composition',\n", " ],\n", " sets_overwrite=model_setup_feature_sets,\n", ")\n", "rt_encoder = CovariateFeatureEncoder(\n", - " feature_list=[\"molecular_weight\", \"precursor_mode\", \"instrument\"],\n", + " feature_list=['molecular_weight', 'precursor_mode', 'instrument'],\n", " sets_overwrite=model_setup_feature_sets,\n", ")\n", "\n", "\n", "def process_dataframes(df_train, df_test):\n", "\n", - " df_train[\"Metabolite\"] = df_train[\"SMILES\"].apply(\n", + " df_train['Metabolite'] = df_train['SMILES'].apply(\n", " Metabolite\n", " ) # TRAIN Metabolites are only tracked for tanimoto distance\n", - " df_test[\"Metabolite\"] = df_test[\"SMILES\"].apply(Metabolite)\n", - " df_test[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + " df_test['Metabolite'] = df_test['SMILES'].apply(Metabolite)\n", + " df_test['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - " covariate_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", - " covariate_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", - " rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", + " covariate_encoder.normalize_features['collision_energy']['max'] = CE_upper_limit\n", + " covariate_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", + " rt_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", "\n", - " df_test[\"Metabolite\"].apply(\n", + " df_test['Metabolite'].apply(\n", " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", " )\n", - " df_test.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]), axis=1)\n", + " df_test.apply(lambda x: x['Metabolite'].set_id(x['group_id']), axis=1)\n", "\n", " # df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", " df_test.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(\n", - " x[\"summary\"], covariate_encoder, rt_encoder\n", + " lambda x: x['Metabolite'].add_metadata(\n", + " x['summary'], covariate_encoder, rt_encoder\n", " ),\n", " axis=1,\n", " )\n", " df_train.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], process_metadata=False),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], process_metadata=False),\n", " axis=1,\n", " )\n", "\n", - " df_test[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + " df_test['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", " df_test.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], 
tolerance=x[\"ppm_peak_tolerance\"]\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=x['ppm_peak_tolerance']\n", " ),\n", " axis=1,\n", " )\n", @@ -363,30 +362,30 @@ "metadata": {}, "outputs": [], "source": [ - "casmi16_path = f\"{home}/data/metabolites/CASMI_2016/casmi16_withCCS.csv\"\n", - "casmi22_path = f\"{home}/data/metabolites/CASMI_2022/casmi22_withCCS.csv\"\n", + "casmi16_path = f'{home}/data/metabolites/CASMI_2016/casmi16_withCCS.csv'\n", + "casmi22_path = f'{home}/data/metabolites/CASMI_2022/casmi22_withCCS.csv'\n", "\n", "df_cas = pd.read_csv(casmi16_path, index_col=[0], low_memory=False)\n", "df_cast = pd.read_csv(\n", - " f\"{home}/data/metabolites/CASMI_2016/casmi16t_withCCS.csv\",\n", + " f'{home}/data/metabolites/CASMI_2016/casmi16t_withCCS.csv',\n", " index_col=[0],\n", " low_memory=False,\n", ") # f\"{home}/data/metabolites/CASMI_2016/casmi16_training_combined.csv\"\n", "df_cas22 = pd.read_csv(casmi22_path, index_col=[0], low_memory=False)\n", "\n", "# Restore dictionary values\n", - "dict_columns = [\"peaks\", \"Candidates\"]\n", + "dict_columns = ['peaks', 'Candidates']\n", "for col in dict_columns:\n", " df_cas[col] = df_cas[col].apply(ast.literal_eval)\n", " df_cast[col] = df_cast[col].apply(ast.literal_eval)\n", "\n", - "df_cas[\"is_priority\"] = True\n", - "df_cast[\"is_priority\"] = False\n", - "df_cas22[\"peaks\"] = df_cas22[\"peaks\"].apply(ast.literal_eval)\n", - "df_cas22[\"ChallengeNum\"] = df_cas22[\"ChallengeName\"].apply(\n", - " lambda x: int(x.split(\"-\")[-1])\n", + "df_cas['is_priority'] = True\n", + "df_cast['is_priority'] = False\n", + "df_cas22['peaks'] = df_cas22['peaks'].apply(ast.literal_eval)\n", + "df_cas22['ChallengeNum'] = df_cas22['ChallengeName'].apply(\n", + " lambda x: int(x.split('-')[-1])\n", ")\n", - "df_cas22[\"is_priority\"] = (df_cas22[\"ChallengeNum\"] < 250).astype(bool)\n", + "df_cas22['is_priority'] = (df_cas22['ChallengeNum'] < 250).astype(bool)\n", "\n", "\n", "def closest_cfm_ce(CE):\n", @@ -465,37 +464,37 @@ "source": [ "from fiora.MOL.collision_energy import NCE_to_eV\n", "\n", - "df_cas[\"dataset\"] = \"CASMI 16\"\n", - "df_cas[\"RETENTIONTIME\"] = df_cas[\"RTINSECONDS\"] / 60.0\n", - "df_cas[\"Metabolite\"] = df_cas[\"SMILES\"].apply(Metabolite)\n", - "df_cas[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df_cas['dataset'] = 'CASMI 16'\n", + "df_cas['RETENTIONTIME'] = df_cas['RTINSECONDS'] / 60.0\n", + "df_cas['Metabolite'] = df_cas['SMILES'].apply(Metabolite)\n", + "df_cas['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cas[\"Metabolite\"].apply(\n", + "df_cas['Metabolite'].apply(\n", " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", ")\n", - "df_cas[\"CE\"] = 20.0 # actually stepped 20/35/50\n", - "df_cas[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", + "df_cas['CE'] = 20.0 # actually stepped 20/35/50\n", + "df_cas['Instrument_type'] = 'HCD' # CHECK if correct Orbitrap\n", "\n", "metadata_key_map16 = {\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"PRECURSOR_MZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'precursor_mz': 'PRECURSOR_MZ',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'RETENTIONTIME',\n", "}\n", "\n", - 
"df_cas[\"summary\"] = df_cas.apply(\n", + "df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", ")\n", "df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder), axis=1\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], covariate_encoder), axis=1\n", ")\n", "\n", "# Fragmentation\n", - "df_cas[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df_cas['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=300 * PPM\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=300 * PPM\n", " ),\n", " axis=1,\n", ")\n", @@ -504,51 +503,51 @@ "# CASMI 22\n", "#\n", "\n", - "df_cas22[\"dataset\"] = \"CASMI 22\"\n", - "df_cas22[\"Metabolite\"] = df_cas22[\"SMILES\"].apply(Metabolite)\n", - "df_cas22[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df_cas22['dataset'] = 'CASMI 22'\n", + "df_cas22['Metabolite'] = df_cas22['SMILES'].apply(Metabolite)\n", + "df_cas22['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cas22[\"Metabolite\"].apply(\n", + "df_cas22['Metabolite'].apply(\n", " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", ")\n", - "df_cas22[\"CE\"] = df_cas22.apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"precursor_mz\"]), axis=1\n", + "df_cas22['CE'] = df_cas22.apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['precursor_mz']), axis=1\n", ")\n", - "df_cas22 = df_cas22[df_cas22[\"CE\"] < CE_upper_limit]\n", - "df_cas22 = df_cas22[df_cas22[\"CE\"] > 0]\n", + "df_cas22 = df_cas22[df_cas22['CE'] < CE_upper_limit]\n", + "df_cas22 = df_cas22[df_cas22['CE'] > 0]\n", "# df_cas22 = df_cas22[df_cas22.is_priority]\n", "\n", "metadata_key_map22 = {\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"precursor_mz\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"ChallengeRT\",\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'precursor_mz': 'precursor_mz',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'ChallengeRT',\n", "}\n", "\n", - "df_cas22[\"summary\"] = df_cas22.apply(\n", + "df_cas22['summary'] = df_cas22.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map22.items()}, axis=1\n", ")\n", "df_cas22.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], covariate_encoder, rt_encoder),\n", " axis=1,\n", ")\n", "\n", "# Fragmentation\n", - "df_cas22[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df_cas22['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "df_cas22.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=300 * PPM\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=300 * PPM\n", " ),\n", " axis=1,\n", ") # Optional: use mz_cut instead\n", "\n", "df_cas22 = df_cas22.reset_index()\n", "\n", - "df_cas22[\"library\"] = \"CASMI-22\"\n", - "df_cas22[\"RETENTIONTIME\"] = df_cas22[\"ChallengeRT\"] # \"RT_min\"\n", - 
"df_cas22[\"cfm_CE\"] = df_cas22[\"CE\"].apply(closest_cfm_ce)\n", - "df_cas22[[\"NCE\", \"CE\", \"cfm_CE\"]].head(3)" + "df_cas22['library'] = 'CASMI-22'\n", + "df_cas22['RETENTIONTIME'] = df_cas22['ChallengeRT'] # \"RT_min\"\n", + "df_cas22['cfm_CE'] = df_cas22['CE'].apply(closest_cfm_ce)\n", + "df_cas22[['NCE', 'CE', 'cfm_CE']].head(3)" ] }, { @@ -557,39 +556,37 @@ "metadata": {}, "outputs": [], "source": [ - "from fiora.MOL.collision_energy import NCE_to_eV\n", - "\n", - "df_cast[\"dataset\"] = \"CASMI 16 Training\"\n", - "df_cast[\"RETENTIONTIME\"] = df_cast[\"RTINSECONDS\"] / 60.0\n", - "df_cast[\"Metabolite\"] = df_cast[\"SMILES\"].apply(Metabolite)\n", - "df_cast[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df_cast['dataset'] = 'CASMI 16 Training'\n", + "df_cast['RETENTIONTIME'] = df_cast['RTINSECONDS'] / 60.0\n", + "df_cast['Metabolite'] = df_cast['SMILES'].apply(Metabolite)\n", + "df_cast['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cast[\"Metabolite\"].apply(\n", + "df_cast['Metabolite'].apply(\n", " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", ")\n", - "df_cast[\"CE\"] = 20.0 # actually stepped 20/35/50\n", - "df_cast[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", + "df_cast['CE'] = 20.0 # actually stepped 20/35/50\n", + "df_cast['Instrument_type'] = 'HCD' # CHECK if correct Orbitrap\n", "\n", "metadata_key_map16 = {\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"PRECURSOR_MZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'precursor_mz': 'PRECURSOR_MZ',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'RETENTIONTIME',\n", "}\n", "\n", - "df_cast[\"summary\"] = df_cast.apply(\n", + "df_cast['summary'] = df_cast.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", ")\n", "df_cast.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder), axis=1\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], covariate_encoder), axis=1\n", ")\n", "\n", "# Fragmentation\n", - "df_cast[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df_cast['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "_ = df_cast.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=300 * PPM\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=300 * PPM\n", " ),\n", " axis=1,\n", ")" @@ -617,15 +614,16 @@ } ], "source": [ - "from fiora.GNN.Trainer import Trainer\n", "import torch_geometric as geom\n", "\n", + "from fiora.GNN.Trainer import Trainer\n", + "\n", "if torch.cuda.is_available():\n", - " dev = \"cuda:0\"\n", + " dev = 'cuda:0'\n", "else:\n", - " dev = \"cpu\"\n", + " dev = 'cpu'\n", "\n", - "print(f\"Running on device: {dev}\")" + "print(f'Running on device: {dev}')" ] }, { @@ -708,7 +706,7 @@ "metadata": {}, "outputs": [], "source": [ - "np.seterr(invalid=\"ignore\")\n", + "np.seterr(invalid='ignore')\n", "\n", "\n", "def simulate_all(model, DF):\n", @@ -733,137 +731,136 @@ "metadata": {}, "outputs": [], "source": [ - "from fiora.MOL.collision_energy import NCE_to_eV\n", + "from fiora.MS.ms_utility import 
merge_annotated_spectrum\n", "from fiora.MS.spectral_scores import (\n", + " reweighted_dot,\n", " spectral_cosine,\n", " spectral_reflection_cosine,\n", - " reweighted_dot,\n", ")\n", - "from fiora.MS.ms_utility import merge_annotated_spectrum\n", "\n", "\n", "def test_cas16(model, df_cas=df_cas, ignore_RT=False):\n", "\n", " # Predict spectra for first NCE step\n", - " df_cas[\"NCE\"] = 20.0\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " df_cas['NCE'] = 20.0\n", + " df_cas['CE'] = df_cas[['NCE', 'PRECURSOR_MZ']].apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['PRECURSOR_MZ']), axis=1\n", " )\n", - " df_cas[\"step1_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(\n", + " df_cas['step1_CE'] = df_cas['CE']\n", + " df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", " )\n", " df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(\n", - " x[\"summary\"], covariate_encoder, rt_encoder\n", + " lambda x: x['Metabolite'].add_metadata(\n", + " x['summary'], covariate_encoder, rt_encoder\n", " ),\n", " axis=1,\n", " )\n", - " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_20\")\n", + " df_cas = fiora.simulate_all(df_cas, model, suffix='_20')\n", "\n", " # Predict spectra for last NCE step\n", - " df_cas[\"NCE\"] = 50.0\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " df_cas['NCE'] = 50.0\n", + " df_cas['CE'] = df_cas[['NCE', 'PRECURSOR_MZ']].apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['PRECURSOR_MZ']), axis=1\n", " )\n", - " df_cas[\"step3_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(\n", + " df_cas['step3_CE'] = df_cas['CE']\n", + " df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", " )\n", " df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(\n", - " x[\"summary\"], covariate_encoder, rt_encoder\n", + " lambda x: x['Metabolite'].add_metadata(\n", + " x['summary'], covariate_encoder, rt_encoder\n", " ),\n", " axis=1,\n", " )\n", - " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_50\")\n", + " df_cas = fiora.simulate_all(df_cas, model, suffix='_50')\n", "\n", " # Predict spectra for middle (average) NCE step (doing this last makes sure Metabolite metadata references the average case)\n", - " df_cas[\"NCE\"] = 35.0\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " df_cas['NCE'] = 35.0\n", + " df_cas['CE'] = df_cas[['NCE', 'PRECURSOR_MZ']].apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['PRECURSOR_MZ']), axis=1\n", " )\n", - " df_cas[\"step2_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(\n", + " df_cas['step2_CE'] = df_cas['CE']\n", + " df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", " )\n", " df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(\n", - " x[\"summary\"], covariate_encoder, rt_encoder\n", + " lambda x: x['Metabolite'].add_metadata(\n", + " x['summary'], covariate_encoder, rt_encoder\n", " ),\n", " axis=1,\n", " )\n", - " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_35\")\n", + " df_cas = fiora.simulate_all(df_cas, model, suffix='_35')\n", "\n", - " df_cas[\"avg_CE\"] = (\n", - " 
df_cas[\"step1_CE\"] + df_cas[\"step2_CE\"] + df_cas[\"step3_CE\"]\n", + " df_cas['avg_CE'] = (\n", + " df_cas['step1_CE'] + df_cas['step2_CE'] + df_cas['step3_CE']\n", " ) / 3\n", - " df_cas[\"CE\"] = df_cas[\"avg_CE\"]\n", + " df_cas['CE'] = df_cas['avg_CE']\n", "\n", - " df_cas[\"merged_peaks\"] = df_cas.apply(\n", + " df_cas['merged_peaks'] = df_cas.apply(\n", " lambda x: merge_annotated_spectrum(\n", - " merge_annotated_spectrum(x[\"sim_peaks_20\"], x[\"sim_peaks_35\"]),\n", - " x[\"sim_peaks_50\"],\n", + " merge_annotated_spectrum(x['sim_peaks_20'], x['sim_peaks_35']),\n", + " x['sim_peaks_50'],\n", " ),\n", " axis=1,\n", " )\n", - " df_cas[\"sim_peaks\"] = df_cas[\"merged_peaks\"]\n", - " df_cas[\"merged_cosine\"] = df_cas.apply(\n", - " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " df_cas['sim_peaks'] = df_cas['merged_peaks']\n", + " df_cas['merged_cosine'] = df_cas.apply(\n", + " lambda x: spectral_cosine(x['peaks'], x['merged_peaks']), axis=1\n", " )\n", - " df_cas[\"merged_sqrt_cosine\"] = df_cas.apply(\n", - " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt),\n", + " df_cas['merged_sqrt_cosine'] = df_cas.apply(\n", + " lambda x: spectral_cosine(x['peaks'], x['merged_peaks'], transform=np.sqrt),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_sqrt_bias\"] = df_cas.apply(\n", + " df_cas['merged_sqrt_bias'] = df_cas.apply(\n", " lambda x: spectral_cosine(\n", - " x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt, with_bias=True\n", + " x['peaks'], x['merged_peaks'], transform=np.sqrt, with_bias=True\n", " )[1],\n", " axis=1,\n", " )\n", - " df_cas[\"merged_sqrt_cosine_wo_precursor\"] = df_cas.apply(\n", + " df_cas['merged_sqrt_cosine_wo_precursor'] = df_cas.apply(\n", " lambda x: spectral_cosine(\n", - " x[\"peaks\"],\n", - " x[\"merged_peaks\"],\n", + " x['peaks'],\n", + " x['merged_peaks'],\n", " transform=np.sqrt,\n", - " remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(\n", - " x[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " remove_mz=x['Metabolite'].get_theoretical_precursor_mz(\n", + " x['Metabolite'].metadata['precursor_mode']\n", " ),\n", " ),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_refl_cosine\"] = df_cas.apply(\n", + " df_cas['merged_refl_cosine'] = df_cas.apply(\n", " lambda x: spectral_reflection_cosine(\n", - " x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt\n", + " x['peaks'], x['merged_peaks'], transform=np.sqrt\n", " ),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_steins\"] = df_cas.apply(\n", - " lambda x: reweighted_dot(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " df_cas['merged_steins'] = df_cas.apply(\n", + " lambda x: reweighted_dot(x['peaks'], x['merged_peaks']), axis=1\n", " )\n", - " df_cas[\"spectral_cosine\"] = df_cas[\"merged_cosine\"] # just remember it is merged\n", - " df_cas[\"spectral_sqrt_cosine\"] = df_cas[\n", - " \"merged_sqrt_cosine\"\n", + " df_cas['spectral_cosine'] = df_cas['merged_cosine'] # just remember it is merged\n", + " df_cas['spectral_sqrt_cosine'] = df_cas[\n", + " 'merged_sqrt_cosine'\n", " ] # just remember it is merged\n", - " df_cas[\"spectral_sqrt_cosine_wo_prec\"] = df_cas[\n", - " \"merged_sqrt_cosine_wo_precursor\"\n", + " df_cas['spectral_sqrt_cosine_wo_prec'] = df_cas[\n", + " 'merged_sqrt_cosine_wo_precursor'\n", " ] # just remember it is merged\n", - " df_cas[\"spectral_sqrt_cosine_avg\"] = (\n", - " df_cas[\"spectral_sqrt_cosine\"] + df_cas[\"spectral_sqrt_cosine_wo_prec\"]\n", + " df_cas['spectral_sqrt_cosine_avg'] = (\n", 
+ " df_cas['spectral_sqrt_cosine'] + df_cas['spectral_sqrt_cosine_wo_prec']\n", " ) / 2.0\n", - " df_cas[\"spectral_sqrt_bias\"] = df_cas[\n", - " \"merged_sqrt_bias\"\n", + " df_cas['spectral_sqrt_bias'] = df_cas[\n", + " 'merged_sqrt_bias'\n", " ] # just remember it is merged\n", "\n", - " df_cas[\"coverage\"] = df_cas[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", + " df_cas['coverage'] = df_cas['Metabolite'].apply(lambda x: x.match_stats['coverage'])\n", " if not ignore_RT:\n", - " df_cas[\"RT_pred\"] = df_cas[\"RT_pred_35\"]\n", - " df_cas[\"RT_dif\"] = df_cas[\"RT_dif_35\"]\n", - " df_cas[\"CCS_pred\"] = df_cas[\"CCS_pred_35\"]\n", + " df_cas['RT_pred'] = df_cas['RT_pred_35']\n", + " df_cas['RT_dif'] = df_cas['RT_dif_35']\n", + " df_cas['CCS_pred'] = df_cas['CCS_pred_35']\n", "\n", - " df_cas[\"library\"] = \"CASMI-16\"\n", + " df_cas['library'] = 'CASMI-16'\n", "\n", - " df_cas[\"cfm_CE\"] = df_cas[\"avg_CE\"].apply(closest_cfm_ce)\n", + " df_cas['cfm_CE'] = df_cas['avg_CE'].apply(closest_cfm_ce)\n", "\n", " return df_cas" ] @@ -15771,14 +15768,14 @@ } ], "source": [ - "print(f\"Testing the model\")\n", - "np.seterr(invalid=\"ignore\")\n", + "print('Testing the model')\n", + "np.seterr(invalid='ignore')\n", "df_test = test_model(model, df_test)\n", "df_msnlib_test = test_model(model, df_msnlib_test)\n", "df_cas = test_cas16(model, ignore_RT=True)\n", "df_cast = test_cas16(model, df_cas=df_cast, ignore_RT=True)\n", "df_cas22 = test_model(model, df_cas22)\n", - "print(\"Done\")" + "print('Done')" ] }, { @@ -15855,50 +15852,50 @@ "import fiora.IO.cfmReader as cfmReader\n", "\n", "cf = cfmReader.read(\n", - " f\"{home}/data/metabolites/cfm-id/msnlib_test_split_negative_predictions.txt\",\n", + " f'{home}/data/metabolites/cfm-id/msnlib_test_split_negative_predictions.txt',\n", " as_df=True,\n", ")\n", "cf_p = cfmReader.read(\n", - " f\"{home}/data/metabolites/cfm-id/msnlib_test_split_positive_predictions.txt\",\n", + " f'{home}/data/metabolites/cfm-id/msnlib_test_split_positive_predictions.txt',\n", " as_df=True,\n", ")\n", - "cf[\"ion_type\"] = \"negative\"\n", - "cf_p[\"ion_type\"] = \"positive\"\n", + "cf['ion_type'] = 'negative'\n", + "cf_p['ion_type'] = 'positive'\n", "cf = pd.concat([cf, cf_p])\n", - "cf[\"#ID\"] = cf[\"#ID\"].astype(int)\n", - "df_msnlib_test[\"cfm_CE\"] = df_msnlib_test[\"CE\"].apply(closest_cfm_ce)\n", - "df_msnlib_test[\"cfm_peaks\"] = None\n", - "df_msnlib_test[[\"cfm_cosine\", \"cfm_sqrt_cosine\", \"cfm_refl_cosine\"]] = np.nan\n", + "cf['#ID'] = cf['#ID'].astype(int)\n", + "df_msnlib_test['cfm_CE'] = df_msnlib_test['CE'].apply(closest_cfm_ce)\n", + "df_msnlib_test['cfm_peaks'] = None\n", + "df_msnlib_test[['cfm_cosine', 'cfm_sqrt_cosine', 'cfm_refl_cosine']] = np.nan\n", "for i, data in df_msnlib_test.iterrows():\n", - " group_id = int(data[\"group_id\"])\n", - " precursor_type = data[\"Precursor_type\"]\n", + " group_id = int(data['group_id'])\n", + " precursor_type = data['Precursor_type']\n", "\n", - " if len(cf[(cf[\"#ID\"] == group_id) & (cf[\"Precursor_type\"] == precursor_type)]) < 1:\n", - " print(f\"{group_id} not found in CFM-ID results. Skipping.\")\n", + " if len(cf[(cf['#ID'] == group_id) & (cf['Precursor_type'] == precursor_type)]) < 1:\n", + " print(f'{group_id} not found in CFM-ID results. 
Skipping.')\n",
 "        continue\n",
 "\n",
 "    cfm_data = cf[\n",
- "        (cf[\"#ID\"] == group_id) & (cf[\"Precursor_type\"] == precursor_type)\n",
+ "        (cf['#ID'] == group_id) & (cf['Precursor_type'] == precursor_type)\n",
 "    ].iloc[0]\n",
 "\n",
- "    cfm_peaks = cfm_data[\"peaks\" + data[\"cfm_CE\"]]  # find best reference CE\n",
- "    df_msnlib_test.at[i, \"cfm_peaks\"] = cfm_peaks\n",
- "    df_msnlib_test.at[i, \"cfm_cosine\"] = spectral_cosine(data[\"peaks\"], cfm_peaks)\n",
- "    df_msnlib_test.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(\n",
- "        data[\"peaks\"], cfm_peaks, transform=np.sqrt\n",
+ "    cfm_peaks = cfm_data['peaks' + data['cfm_CE']]  # find best reference CE\n",
+ "    df_msnlib_test.at[i, 'cfm_peaks'] = cfm_peaks\n",
+ "    df_msnlib_test.at[i, 'cfm_cosine'] = spectral_cosine(data['peaks'], cfm_peaks)\n",
+ "    df_msnlib_test.at[i, 'cfm_sqrt_cosine'] = spectral_cosine(\n",
+ "        data['peaks'], cfm_peaks, transform=np.sqrt\n",
 "    )\n",
- "    df_msnlib_test.at[i, \"cfm_sqrt_cosine_wo_prec\"] = spectral_cosine(\n",
- "        data[\"peaks\"],\n",
+ "    df_msnlib_test.at[i, 'cfm_sqrt_cosine_wo_prec'] = spectral_cosine(\n",
+ "        data['peaks'],\n",
 "        cfm_peaks,\n",
 "        transform=np.sqrt,\n",
- "        remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(\n",
- "            ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]\n",
+ "        remove_mz=data['Metabolite'].get_theoretical_precursor_mz(\n",
+ "            ion_type=data['Metabolite'].metadata['precursor_mode']\n",
 "        ),\n",
 "    )\n",
- "    df_msnlib_test.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(\n",
- "        data[\"peaks\"], cfm_peaks, transform=np.sqrt\n",
+ "    df_msnlib_test.at[i, 'cfm_refl_cosine'] = spectral_reflection_cosine(\n",
+ "        data['peaks'], cfm_peaks, transform=np.sqrt\n",
 "    )\n",
- "    df_msnlib_test.at[i, \"cfm_steins\"] = reweighted_dot(data[\"peaks\"], cfm_peaks)"
+ "    df_msnlib_test.at[i, 'cfm_steins'] = reweighted_dot(data['peaks'], cfm_peaks)"
 ]
},
{
@@ -15919,46 +15916,47 @@
 ],
 "source": [
 "import fiora.IO.cfmReader as cfmReader\n",
+ "\n",
 "# time CFM-ID 4: -> 12m16,571s\n",
 "\n",
 "cf = cfmReader.read(\n",
- "    f\"{home}/data/metabolites/cfm-id/casmi16_negative_predictions.txt\", as_df=True\n",
+ "    f'{home}/data/metabolites/cfm-id/casmi16_negative_predictions.txt', as_df=True\n",
 ")\n",
 "cf_p = cfmReader.read(\n",
- "    f\"{home}/data/metabolites/cfm-id/casmi16_positive_predictions.txt\", as_df=True\n",
+ "    f'{home}/data/metabolites/cfm-id/casmi16_positive_predictions.txt', as_df=True\n",
 ")\n",
 "cf = pd.concat([cf, cf_p])\n",
- "len(cf[cf[\"#ID\"] == \"Challenge-009\"])  ## missing chalenges\n",
- "df_cas[\"cfm_peaks\"] = None\n",
- "df_cas[[\"cfm_cosine\", \"cfm_sqrt_cosine\", \"cfm_refl_cosine\"]] = np.nan\n",
+ "len(cf[cf['#ID'] == 'Challenge-009'])  ## missing challenges\n",
+ "df_cas['cfm_peaks'] = None\n",
+ "df_cas[['cfm_cosine', 'cfm_sqrt_cosine', 'cfm_refl_cosine']] = np.nan\n",
 "for i, cas in df_cas.iterrows():\n",
- "    challenge = cas[\"ChallengeName\"]\n",
+ "    challenge = cas['ChallengeName']\n",
 "\n",
- "    if len(cf[cf[\"#ID\"] == challenge]) != 1:\n",
- "        print(f\"{challenge} not found in CFM-ID results. Skipping.\")\n",
+ "    if len(cf[cf['#ID'] == challenge]) != 1:\n",
+ "        print(f'{challenge} not found in CFM-ID results. 
Skipping.')\n",
 "        continue\n",
- "    cfm_data = cf[cf[\"#ID\"] == challenge].iloc[0]\n",
- "\n",
- "    if cas[\"ChallengeName\"] != cfm_data[\"#ID\"]:\n",
- "        raise ValueError(\"Wrong challenge matched\")\n",
- "    cfm_peaks = cfm_data[\"peaks\" + cas[\"cfm_CE\"]]  # find best reference CE\n",
- "    df_cas.at[i, \"cfm_peaks\"] = cfm_peaks\n",
- "    df_cas.at[i, \"cfm_cosine\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks)\n",
- "    df_cas.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(\n",
- "        cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n",
+ "    cfm_data = cf[cf['#ID'] == challenge].iloc[0]\n",
+ "\n",
+ "    if cas['ChallengeName'] != cfm_data['#ID']:\n",
+ "        raise ValueError('Wrong challenge matched')\n",
+ "    cfm_peaks = cfm_data['peaks' + cas['cfm_CE']]  # find best reference CE\n",
+ "    df_cas.at[i, 'cfm_peaks'] = cfm_peaks\n",
+ "    df_cas.at[i, 'cfm_cosine'] = spectral_cosine(cas['peaks'], cfm_peaks)\n",
+ "    df_cas.at[i, 'cfm_sqrt_cosine'] = spectral_cosine(\n",
+ "        cas['peaks'], cfm_peaks, transform=np.sqrt\n",
 "    )\n",
- "    df_cas.at[i, \"cfm_sqrt_cosine_wo_prec\"] = spectral_cosine(\n",
- "        cas[\"peaks\"],\n",
+ "    df_cas.at[i, 'cfm_sqrt_cosine_wo_prec'] = spectral_cosine(\n",
+ "        cas['peaks'],\n",
 "        cfm_peaks,\n",
 "        transform=np.sqrt,\n",
- "        remove_mz=cas[\"Metabolite\"].get_theoretical_precursor_mz(\n",
- "            ion_type=cas[\"Metabolite\"].metadata[\"precursor_mode\"]\n",
+ "        remove_mz=cas['Metabolite'].get_theoretical_precursor_mz(\n",
+ "            ion_type=cas['Metabolite'].metadata['precursor_mode']\n",
 "        ),\n",
 "    )\n",
- "    df_cas.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(\n",
- "        cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n",
+ "    df_cas.at[i, 'cfm_refl_cosine'] = spectral_reflection_cosine(\n",
+ "        cas['peaks'], cfm_peaks, transform=np.sqrt\n",
 "    )\n",
- "    df_cas.at[i, \"cfm_steins\"] = reweighted_dot(cas[\"peaks\"], cfm_peaks)"
+ "    df_cas.at[i, 'cfm_steins'] = reweighted_dot(cas['peaks'], cfm_peaks)"
 ]
},
{
@@ -15979,48 +15977,49 @@
 ],
 "source": [
 "import fiora.IO.cfmReader as cfmReader\n",
+ "\n",
 "# time CFM-ID 4: -> 12m16,571s\n",
 "\n",
 "cf = cfmReader.read(\n",
- "    f\"{home}/data/metabolites/cfm-id/casmi16t_negative_predictions.txt\", as_df=True\n",
+ "    f'{home}/data/metabolites/cfm-id/casmi16t_negative_predictions.txt', as_df=True\n",
 ")\n",
 "cf_p = cfmReader.read(\n",
- "    f\"{home}/data/metabolites/cfm-id/casmi16t_positive_predictions.txt\", as_df=True\n",
+ "    f'{home}/data/metabolites/cfm-id/casmi16t_positive_predictions.txt', as_df=True\n",
 ")\n",
 "cf = pd.concat([cf, cf_p])\n",
- "len(cf[cf[\"#ID\"] == \"Challenge-009\"])  ## missing chalenges\n",
- "df_cast[\"cfm_peaks\"] = None\n",
+ "len(cf[cf['#ID'] == 'Challenge-009'])  ## missing challenges\n",
+ "df_cast['cfm_peaks'] = None\n",
 "df_cast[\n",
- "    [\"cfm_cosine\", \"cfm_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\", \"cfm_refl_cosine\"]\n",
+ "    ['cfm_cosine', 'cfm_sqrt_cosine', 'cfm_sqrt_cosine_wo_prec', 'cfm_refl_cosine']\n",
 "] = np.nan\n",
 "for i, cas in df_cast.iterrows():\n",
- "    challenge = cas[\"ChallengeName\"]\n",
+ "    challenge = cas['ChallengeName']\n",
 "\n",
- "    if len(cf[cf[\"#ID\"] == challenge]) != 1:\n",
- "        print(f\"{challenge} not found in CFM-ID results. Skipping.\")\n",
+ "    if len(cf[cf['#ID'] == challenge]) != 1:\n",
+ "        print(f'{challenge} not found in CFM-ID results. 
Skipping.')\n", " continue\n", - " cfm_data = cf[cf[\"#ID\"] == challenge].iloc[0]\n", - "\n", - " if cas[\"ChallengeName\"] != cfm_data[\"#ID\"]:\n", - " raise ValueError(\"Wrong challenge matched\")\n", - " cfm_peaks = cfm_data[\"peaks\" + cas[\"cfm_CE\"]] # find best reference CE\n", - " df_cast.at[i, \"cfm_peaks\"] = cfm_peaks\n", - " df_cast.at[i, \"cfm_cosine\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks)\n", - " df_cast.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(\n", - " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " cfm_data = cf[cf['#ID'] == challenge].iloc[0]\n", + "\n", + " if cas['ChallengeName'] != cfm_data['#ID']:\n", + " raise ValueError('Wrong challenge matched')\n", + " cfm_peaks = cfm_data['peaks' + cas['cfm_CE']] # find best reference CE\n", + " df_cast.at[i, 'cfm_peaks'] = cfm_peaks\n", + " df_cast.at[i, 'cfm_cosine'] = spectral_cosine(cas['peaks'], cfm_peaks)\n", + " df_cast.at[i, 'cfm_sqrt_cosine'] = spectral_cosine(\n", + " cas['peaks'], cfm_peaks, transform=np.sqrt\n", " )\n", - " df_cast.at[i, \"cfm_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", - " cas[\"peaks\"],\n", + " df_cast.at[i, 'cfm_sqrt_cosine_wo_prec'] = spectral_cosine(\n", + " cas['peaks'],\n", " cfm_peaks,\n", " transform=np.sqrt,\n", - " remove_mz=cas[\"Metabolite\"].get_theoretical_precursor_mz(\n", - " ion_type=cas[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " remove_mz=cas['Metabolite'].get_theoretical_precursor_mz(\n", + " ion_type=cas['Metabolite'].metadata['precursor_mode']\n", " ),\n", " )\n", - " df_cast.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(\n", - " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " df_cast.at[i, 'cfm_refl_cosine'] = spectral_reflection_cosine(\n", + " cas['peaks'], cfm_peaks, transform=np.sqrt\n", " )\n", - " df_cast.at[i, \"cfm_steins\"] = reweighted_dot(cas[\"peaks\"], cfm_peaks)" + " df_cast.at[i, 'cfm_steins'] = reweighted_dot(cas['peaks'], cfm_peaks)" ] }, { @@ -16030,48 +16029,49 @@ "outputs": [], "source": [ "import fiora.IO.cfmReader as cfmReader\n", + "\n", "# time CFM-ID 4: -> 12m16,571s\n", "\n", "\n", "cf22 = cfmReader.read(\n", - " f\"{home}/data/metabolites/cfm-id/casmi22_negative_predictions.txt\", as_df=True\n", + " f'{home}/data/metabolites/cfm-id/casmi22_negative_predictions.txt', as_df=True\n", ")\n", "cf22_p = cfmReader.read(\n", - " f\"{home}/data/metabolites/cfm-id/casmi22_positive_predictions.txt\", as_df=True\n", + " f'{home}/data/metabolites/cfm-id/casmi22_positive_predictions.txt', as_df=True\n", ")\n", "cf22 = pd.concat([cf22, cf22_p])\n", - "df_cas22[\"cfm_peaks\"] = None\n", - "df_cas22[[\"cfm_cosine\", \"cfm_sqrt_cosine\", \"cfm_refl_cosine\"]] = np.nan\n", + "df_cas22['cfm_peaks'] = None\n", + "df_cas22[['cfm_cosine', 'cfm_sqrt_cosine', 'cfm_refl_cosine']] = np.nan\n", "for i, cas in df_cas22.iterrows():\n", - " challenge = cas[\"ChallengeName\"]\n", + " challenge = cas['ChallengeName']\n", "\n", - " if len(cf22[cf22[\"#ID\"] == challenge]) != 1:\n", - " print(f\"{challenge} not found in CFM-ID results. Skipping.\")\n", + " if len(cf22[cf22['#ID'] == challenge]) != 1:\n", + " print(f'{challenge} not found in CFM-ID results. 
Skipping.')\n", " continue\n", - " cfm_data = cf22[cf22[\"#ID\"] == challenge].iloc[0]\n", - "\n", - " if cas[\"ChallengeName\"] != cfm_data[\"#ID\"]:\n", - " raise ValueError(\"Wrong challenge matched\")\n", - " cfm_peaks = cfm_data[\"peaks\" + cas[\"cfm_CE\"]] # find best reference CE\n", - " df_cas22.at[i, \"cfm_peaks\"] = cfm_peaks\n", - " df_cas22.at[i, \"cfm_cosine\"] = spectral_cosine(cas[\"peaks\"], cfm_peaks)\n", - " df_cas22.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(\n", - " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " cfm_data = cf22[cf22['#ID'] == challenge].iloc[0]\n", + "\n", + " if cas['ChallengeName'] != cfm_data['#ID']:\n", + " raise ValueError('Wrong challenge matched')\n", + " cfm_peaks = cfm_data['peaks' + cas['cfm_CE']] # find best reference CE\n", + " df_cas22.at[i, 'cfm_peaks'] = cfm_peaks\n", + " df_cas22.at[i, 'cfm_cosine'] = spectral_cosine(cas['peaks'], cfm_peaks)\n", + " df_cas22.at[i, 'cfm_sqrt_cosine'] = spectral_cosine(\n", + " cas['peaks'], cfm_peaks, transform=np.sqrt\n", " )\n", - " df_cas22.at[i, \"cfm_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", - " cas[\"peaks\"],\n", + " df_cas22.at[i, 'cfm_sqrt_cosine_wo_prec'] = spectral_cosine(\n", + " cas['peaks'],\n", " cfm_peaks,\n", " transform=np.sqrt,\n", - " remove_mz=cas[\"Metabolite\"].get_theoretical_precursor_mz(\n", - " ion_type=cas[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " remove_mz=cas['Metabolite'].get_theoretical_precursor_mz(\n", + " ion_type=cas['Metabolite'].metadata['precursor_mode']\n", " ),\n", " )\n", - " df_cas22.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(\n", - " cas[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " df_cas22.at[i, 'cfm_refl_cosine'] = spectral_reflection_cosine(\n", + " cas['peaks'], cfm_peaks, transform=np.sqrt\n", " )\n", - " df_cas22.at[i, \"cfm_steins\"] = reweighted_dot(cas[\"peaks\"], cfm_peaks)\n", + " df_cas22.at[i, 'cfm_steins'] = reweighted_dot(cas['peaks'], cfm_peaks)\n", "\n", - "df_cas22[\"is_priority\"] = df_cas22[\"is_priority\"].astype(bool)" + "df_cas22['is_priority'] = df_cas22['is_priority'].astype(bool)" ] }, { @@ -16080,9 +16080,9 @@ "metadata": {}, "outputs": [], "source": [ - "ex_smiles = \"CC(NC(=O)CC1=CNC2=C1C=CC=C2)C(O)=O\"\n", + "ex_smiles = 'CC(NC(=O)CC1=CNC2=C1C=CC=C2)C(O)=O'\n", "ex_metabolite = Metabolite(ex_smiles)\n", - "ex_compound_id = df_test[df_test[\"Metabolite\"] == ex_metabolite][\"group_id\"].iloc[0]" + "ex_compound_id = df_test[df_test['Metabolite'] == ex_metabolite]['group_id'].iloc[0]" ] }, { @@ -16130,38 +16130,38 @@ "\n", "reset_matplotlib()\n", "spec_df = {}\n", - "for i, data in df_test[df_test[\"group_id\"] == ex_compound_id].iterrows():\n", - " cosine = data[\"spectral_sqrt_cosine\"]\n", - " name = data[\"Name\"]\n", + "for i, data in df_test[df_test['group_id'] == ex_compound_id].iterrows():\n", + " cosine = data['spectral_sqrt_cosine']\n", + " name = data['Name']\n", " # t3 = data[\"tanimoto3\"]\n", - " print(f\"{name} ({i}): cosine {cosine:0.2}\")\n", + " print(f'{name} ({i}): cosine {cosine:0.2}')\n", " # print(f\"{name} ({i}): cosine {t3}\") only possible after tanimoto calculation below\n", "\n", " fig, axs = plt.subplots(\n", - " 1, 2, figsize=(12.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(12.8, 4.2), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", " )\n", - " img = data[\"Metabolite\"].draw(ax=axs[0])\n", + " img = data['Metabolite'].draw(ax=axs[0])\n", "\n", " # axs[0].grid(False)\n", " axs[0].tick_params(\n", - " axis=\"both\", 
bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", " )\n", - " axs[0].set_title(data[\"Name\"])\n", + " axs[0].set_title(data['Name'])\n", " # axs[0].imshow(img)\n", " # axs[0].axis(\"off\")\n", " # sv.plot_spectrum(example, ax=axs[1])\n", - " prec = data[\"Precursor_type\"]\n", + " prec = data['Precursor_type']\n", " spec_df.update(\n", " {\n", - " f\"Experimental m/z {prec}\": data[\"peaks\"][\"mz\"],\n", - " f\"Experimental intensity {prec}\": data[\"peaks\"][\"intensity\"],\n", - " f\"Fiora m/z {prec}\": data[\"sim_peaks\"][\"mz\"],\n", - " f\"Fiora intensity {prec}\": data[\"sim_peaks\"][\"intensity\"],\n", + " f'Experimental m/z {prec}': data['peaks']['mz'],\n", + " f'Experimental intensity {prec}': data['peaks']['intensity'],\n", + " f'Fiora m/z {prec}': data['sim_peaks']['mz'],\n", + " f'Fiora intensity {prec}': data['sim_peaks']['intensity'],\n", " }\n", " )\n", "\n", " ax = sv.plot_spectrum(\n", - " data, {\"peaks\": data[\"sim_peaks\"]}, ax=axs[1], highlight_matches=False\n", + " data, {'peaks': data['sim_peaks']}, ax=axs[1], highlight_matches=False\n", " )\n", " # fig.savefig(f\"{home}/images/paper/example_mirror_id{i}.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # # fig.savefig(f\"{home}/images/paper/example_mirror_id{i}.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -16175,8 +16175,8 @@ "metadata": {}, "outputs": [], "source": [ - "spec_df = pd.DataFrame.from_dict(spec_df, orient=\"index\").transpose()\n", - "spec_df = spec_df.fillna(\"-\")" + "spec_df = pd.DataFrame.from_dict(spec_df, orient='index').transpose()\n", + "spec_df = spec_df.fillna('-')" ] }, { @@ -16205,40 +16205,40 @@ "outputs": [], "source": [ "paths = {\n", - " \"[M-H]-\": f\"{home}/data/metabolites/cfm-id/test_pred_neg/\",\n", - " \"[M+H]+\": f\"{home}/data/metabolites/cfm-id/test_pred_pos/\",\n", + " '[M-H]-': f'{home}/data/metabolites/cfm-id/test_pred_neg/',\n", + " '[M+H]+': f'{home}/data/metabolites/cfm-id/test_pred_pos/',\n", "}\n", - "df_test[\"cfm_CE\"] = df_test[\"CE\"].apply(closest_cfm_ce)\n", - "df_test[\"cfm_peaks\"] = None\n", - "df_test[[\"cfm_cosine\", \"cfm_sqrt_cosine\", \"cfm_refl_cosine\"]] = np.nan\n", + "df_test['cfm_CE'] = df_test['CE'].apply(closest_cfm_ce)\n", + "df_test['cfm_peaks'] = None\n", + "df_test[['cfm_cosine', 'cfm_sqrt_cosine', 'cfm_refl_cosine']] = np.nan\n", "for i, data in df_test.iterrows():\n", - " group_id = data[\"group_id\"]\n", - " p = paths[data[\"Precursor_type\"]] + str(int(group_id)) + \".txt\"\n", + " group_id = data['group_id']\n", + " p = paths[data['Precursor_type']] + str(int(group_id)) + '.txt'\n", " cfm_data = cfmReader.read(p, as_df=True)\n", "\n", " # TODO Check smiles / MOL\n", " if cfm_data.shape == (0, 0): # Not predicted by CFM-ID\n", " continue\n", "\n", - " cfm_peaks = cfm_data[\"peaks\" + data[\"cfm_CE\"]].iloc[0]\n", + " cfm_peaks = cfm_data['peaks' + data['cfm_CE']].iloc[0]\n", "\n", - " df_test.at[i, \"cfm_peaks\"] = cfm_peaks\n", - " df_test.at[i, \"cfm_cosine\"] = spectral_cosine(data[\"peaks\"], cfm_peaks)\n", - " df_test.at[i, \"cfm_sqrt_cosine\"] = spectral_cosine(\n", - " data[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " df_test.at[i, 'cfm_peaks'] = cfm_peaks\n", + " df_test.at[i, 'cfm_cosine'] = spectral_cosine(data['peaks'], cfm_peaks)\n", + " df_test.at[i, 'cfm_sqrt_cosine'] = spectral_cosine(\n", + " data['peaks'], cfm_peaks, transform=np.sqrt\n", " )\n", - " df_test.at[i, 
\"cfm_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", - " data[\"peaks\"],\n", + " df_test.at[i, 'cfm_sqrt_cosine_wo_prec'] = spectral_cosine(\n", + " data['peaks'],\n", " cfm_peaks,\n", " transform=np.sqrt,\n", - " remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(\n", - " ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " remove_mz=data['Metabolite'].get_theoretical_precursor_mz(\n", + " ion_type=data['Metabolite'].metadata['precursor_mode']\n", " ),\n", " )\n", - " df_test.at[i, \"cfm_refl_cosine\"] = spectral_reflection_cosine(\n", - " data[\"peaks\"], cfm_peaks, transform=np.sqrt\n", + " df_test.at[i, 'cfm_refl_cosine'] = spectral_reflection_cosine(\n", + " data['peaks'], cfm_peaks, transform=np.sqrt\n", " )\n", - " df_test.at[i, \"cfm_steins\"] = reweighted_dot(data[\"peaks\"], cfm_peaks)" + " df_test.at[i, 'cfm_steins'] = reweighted_dot(data['peaks'], cfm_peaks)" ] }, { @@ -16261,87 +16261,87 @@ "\n", "# CFM-ID query input\n", "if False:\n", - " file = f\"{home}/data/metabolites/cfm-id/test_split_negative_solutions_cfm.txt\"\n", - " df_test[\"group_id\"] = df_test[\"group_id\"].astype(int)\n", - " df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"][\n", - " [\"group_id\", \"SMILES\"]\n", - " ].drop_duplicates(\"group_id\").to_csv(file, index=False, header=False, sep=\" \")\n", - " file = file.replace(\"negative\", \"positive\")\n", - " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][\n", - " [\"group_id\", \"SMILES\"]\n", - " ].drop_duplicates(\"group_id\").to_csv(file, index=False, header=False, sep=\" \")\n", + " file = f'{home}/data/metabolites/cfm-id/test_split_negative_solutions_cfm.txt'\n", + " df_test['group_id'] = df_test['group_id'].astype(int)\n", + " df_test[df_test['Precursor_type'] == '[M-H]-'][\n", + " ['group_id', 'SMILES']\n", + " ].drop_duplicates('group_id').to_csv(file, index=False, header=False, sep=' ')\n", + " file = file.replace('negative', 'positive')\n", + " df_test[df_test['Precursor_type'] == '[M+H]+'][\n", + " ['group_id', 'SMILES']\n", + " ].drop_duplicates('group_id').to_csv(file, index=False, header=False, sep=' ')\n", "\n", "if False:\n", " file = (\n", - " f\"{home}/data/metabolites/cfm-id/msnlib_test_split_negative_solutions_cfm.txt\"\n", + " f'{home}/data/metabolites/cfm-id/msnlib_test_split_negative_solutions_cfm.txt'\n", " )\n", - " df_msnlib_test[\"group_id\"] = df_msnlib_test[\"group_id\"].astype(int)\n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M-H]-\"][\n", - " [\"group_id\", \"SMILES\"]\n", - " ].drop_duplicates(\"group_id\").to_csv(file, index=False, header=False, sep=\" \")\n", - " file = file.replace(\"negative\", \"positive\")\n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][\n", - " [\"group_id\", \"SMILES\"]\n", - " ].drop_duplicates(\"group_id\").to_csv(file, index=False, header=False, sep=\" \")\n", + " df_msnlib_test['group_id'] = df_msnlib_test['group_id'].astype(int)\n", + " df_msnlib_test[df_msnlib_test['Precursor_type'] == '[M-H]-'][\n", + " ['group_id', 'SMILES']\n", + " ].drop_duplicates('group_id').to_csv(file, index=False, header=False, sep=' ')\n", + " file = file.replace('negative', 'positive')\n", + " df_msnlib_test[df_msnlib_test['Precursor_type'] == '[M+H]+'][\n", + " ['group_id', 'SMILES']\n", + " ].drop_duplicates('group_id').to_csv(file, index=False, header=False, sep=' ')\n", "\n", "# ICEBERG/SCARF training/testing input\n", "if False:\n", " # OLD # df_test[\"idx\"] = [f\"spec{i}\" for i,_ in df_test.iterrows()]\n", - " df_test[\"num\"] = 
df_test.groupby(\"group_id\").cumcount() + 1\n", - " df_test[\"idx\"] = (\n", - " \"spec\"\n", - " + df_test[\"group_id\"].astype(int).astype(str)\n", - " + \"_\"\n", - " + df_test[\"num\"].astype(str)\n", + " df_test['num'] = df_test.groupby('group_id').cumcount() + 1\n", + " df_test['idx'] = (\n", + " 'spec'\n", + " + df_test['group_id'].astype(int).astype(str)\n", + " + '_'\n", + " + df_test['num'].astype(str)\n", " )\n", "\n", " # df_train[\"dataset_label\"] = \"df_test\"\n", " label_map = {\n", - " \"idx\": \"spec\",\n", - " \"Name\": \"name\",\n", - " \"Precursor_type\": \"ionization\",\n", - " \"SMILES\": \"smiles\",\n", - " \"InChIKey\": \"inchikey\",\n", + " 'idx': 'spec',\n", + " 'Name': 'name',\n", + " 'Precursor_type': 'ionization',\n", + " 'SMILES': 'smiles',\n", + " 'InChIKey': 'inchikey',\n", " }\n", - " df_test[\"formula\"] = df_test[\"Metabolite\"].apply(lambda x: x.Formula)\n", - " df_test[\"InChIKey\"] = df_test[\"Metabolite\"].apply(lambda x: x.InChIKey)\n", + " df_test['formula'] = df_test['Metabolite'].apply(lambda x: x.Formula)\n", + " df_test['InChIKey'] = df_test['Metabolite'].apply(lambda x: x.InChIKey)\n", " # import fiora.IO.mspredWriter as mspredWriter WRITER bugged?\n", " # mspredWriter.write_labels(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"], f\"{home}/data/metabolites/ms-pred/df_test.tsv\", label_map=label_map)\n", - " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"].rename(columns=label_map)[\n", - " [\"dataset\", \"spec\", \"name\", \"ionization\", \"formula\", \"smiles\", \"inchikey\"]\n", - " ].to_csv(f\"{home}/data/metabolites/ms-pred/df_test.tsv\", index=False, sep=\"\\t\")\n", + " df_test[df_test['Precursor_type'] == '[M+H]+'].rename(columns=label_map)[\n", + " ['dataset', 'spec', 'name', 'ionization', 'formula', 'smiles', 'inchikey']\n", + " ].to_csv(f'{home}/data/metabolites/ms-pred/df_test.tsv', index=False, sep='\\t')\n", "\n", "if False: # MSnLib\n", - " df_msnlib_test[\"num\"] = df_msnlib_test.groupby(\"group_id\").cumcount() + 1\n", - " df_msnlib_test[\"idx\"] = (\n", - " \"spec\"\n", - " + df_msnlib_test[\"group_id\"].astype(int).astype(str)\n", - " + \"_\"\n", - " + df_msnlib_test[\"num\"].astype(str)\n", + " df_msnlib_test['num'] = df_msnlib_test.groupby('group_id').cumcount() + 1\n", + " df_msnlib_test['idx'] = (\n", + " 'spec'\n", + " + df_msnlib_test['group_id'].astype(int).astype(str)\n", + " + '_'\n", + " + df_msnlib_test['num'].astype(str)\n", " )\n", "\n", " # df_train[\"dataset_label\"] = \"df_msnlib_test\"\n", " label_map = {\n", - " \"idx\": \"spec\",\n", - " \"NAME\": \"name\",\n", - " \"Precursor_type\": \"ionization\",\n", - " \"SMILES\": \"smiles\",\n", - " \"INCHIAUX\": \"inchikey\",\n", + " 'idx': 'spec',\n", + " 'NAME': 'name',\n", + " 'Precursor_type': 'ionization',\n", + " 'SMILES': 'smiles',\n", + " 'INCHIAUX': 'inchikey',\n", " }\n", "\n", - " df_msnlib_test[\"formula\"] = df_msnlib_test[\"Metabolite\"].apply(lambda x: x.Formula)\n", - " df_msnlib_test[\"InChIKey\"] = df_msnlib_test[\"Metabolite\"].apply(\n", + " df_msnlib_test['formula'] = df_msnlib_test['Metabolite'].apply(lambda x: x.Formula)\n", + " df_msnlib_test['InChIKey'] = df_msnlib_test['Metabolite'].apply(\n", " lambda x: x.InChIKey\n", " )\n", " # import fiora.IO.mspredWriter as mspredWriter WRITER bugged?\n", " # mspredWriter.write_labels(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"], f\"{home}/data/metabolites/ms-pred/df_test.tsv\", label_map=label_map)\n", "\n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"].rename(\n", 
+ " df_msnlib_test[df_msnlib_test['Precursor_type'] == '[M+H]+'].rename(\n", " columns=label_map\n", " )[\n", - " [\"dataset\", \"spec\", \"name\", \"ionization\", \"formula\", \"smiles\", \"inchikey\"]\n", + " ['dataset', 'spec', 'name', 'ionization', 'formula', 'smiles', 'inchikey']\n", " ].to_csv(\n", - " f\"{home}/data/metabolites/ms-pred/df_msnlib_test.tsv\", index=False, sep=\"\\t\"\n", + " f'{home}/data/metabolites/ms-pred/df_msnlib_test.tsv', index=False, sep='\\t'\n", " )\n", "\n", "if False:\n", @@ -16350,20 +16350,20 @@ " # # from rdkit import Chem\n", " # # from rdkit.Chem import rdMolDescriptors\n", " label_map = {\n", - " \"idx\": \"spec\",\n", - " \"Precursor_type\": \"ionization\",\n", - " \"SMILES\": \"smiles\",\n", - " \"InChIKey\": \"inchikey\",\n", + " 'idx': 'spec',\n", + " 'Precursor_type': 'ionization',\n", + " 'SMILES': 'smiles',\n", + " 'InChIKey': 'inchikey',\n", " }\n", - " df_cas22[\"idx\"] = [f\"spec{i}\" for i, _ in df_cas22.iterrows()]\n", - " df_cas22[\"name\"] = \"Unknown\"\n", - " df_cas22[\"InChIKey\"] = df_cas22[\"Metabolite\"].apply(lambda x: x.InChIKey)\n", - " df_cas22[\"formula\"] = df_cas22[\"Metabolite\"].apply(lambda x: x.Formula)\n", + " df_cas22['idx'] = [f'spec{i}' for i, _ in df_cas22.iterrows()]\n", + " df_cas22['name'] = 'Unknown'\n", + " df_cas22['InChIKey'] = df_cas22['Metabolite'].apply(lambda x: x.InChIKey)\n", + " df_cas22['formula'] = df_cas22['Metabolite'].apply(lambda x: x.Formula)\n", "\n", - " output_file = f\"{home}/data/metabolites/ms-pred/casmi22_positive_labels.tsv\"\n", - " df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"].rename(columns=label_map)[\n", - " [\"dataset\", \"spec\", \"name\", \"formula\", \"ionization\", \"smiles\", \"inchikey\"]\n", - " ].to_csv(output_file, index=False, sep=\"\\t\")\n", + " output_file = f'{home}/data/metabolites/ms-pred/casmi22_positive_labels.tsv'\n", + " df_cas22[df_cas22['Precursor_type'] == '[M+H]+'].rename(columns=label_map)[\n", + " ['dataset', 'spec', 'name', 'formula', 'ionization', 'smiles', 'inchikey']\n", + " ].to_csv(output_file, index=False, sep='\\t')\n", "\n", " # # ### CASMI-16 labels were generated in Casmi16 loader. 
Avoiding repeat here.\n", "\n", @@ -16508,39 +16508,39 @@ "source": [ "import fiora.IO.mspredReader as mspredReader\n", "\n", - "iceberg_dir = f\"{home}/repos/ms-pred/results/test_out_recovery/casmi16/tree_preds_inten\"\n", + "iceberg_dir = f'{home}/repos/ms-pred/results/test_out_recovery/casmi16/tree_preds_inten'\n", "df_ice = mspredReader.read(iceberg_dir)\n", - "df_ice = df_ice.rename(columns={\"peaks\": \"ice_peaks\", \"name\": \"ice_name\"})\n", + "df_ice = df_ice.rename(columns={'peaks': 'ice_peaks', 'name': 'ice_name'})\n", "\n", "\n", "df_cas = pd.merge(\n", " df_cas,\n", - " df_ice[[\"ice_name\", \"ice_peaks\"]],\n", - " left_on=\"ChallengeName\",\n", - " right_on=\"ice_name\",\n", - " how=\"left\",\n", + " df_ice[['ice_name', 'ice_peaks']],\n", + " left_on='ChallengeName',\n", + " right_on='ice_name',\n", + " how='left',\n", ")\n", "\n", "df_cas[\n", - " [\"ice_cosine\", \"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\", \"ice_refl_cosine\"]\n", + " ['ice_cosine', 'ice_sqrt_cosine', 'ice_sqrt_cosine_wo_prec', 'ice_refl_cosine']\n", "] = np.nan\n", - "for i, data in df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"].iterrows():\n", - " df_cas.at[i, \"ice_cosine\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"])\n", - " df_cas.at[i, \"ice_sqrt_cosine\"] = spectral_cosine(\n", - " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + "for i, data in df_cas[df_cas['Precursor_type'] == '[M+H]+'].iterrows():\n", + " df_cas.at[i, 'ice_cosine'] = spectral_cosine(data['peaks'], data['ice_peaks'])\n", + " df_cas.at[i, 'ice_sqrt_cosine'] = spectral_cosine(\n", + " data['peaks'], data['ice_peaks'], transform=np.sqrt\n", " )\n", - " df_cas.at[i, \"ice_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", - " data[\"peaks\"],\n", - " data[\"ice_peaks\"],\n", + " df_cas.at[i, 'ice_sqrt_cosine_wo_prec'] = spectral_cosine(\n", + " data['peaks'],\n", + " data['ice_peaks'],\n", " transform=np.sqrt,\n", - " remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(\n", - " ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " remove_mz=data['Metabolite'].get_theoretical_precursor_mz(\n", + " ion_type=data['Metabolite'].metadata['precursor_mode']\n", " ),\n", " )\n", - " df_cas.at[i, \"ice_refl_cosine\"] = spectral_reflection_cosine(\n", - " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " df_cas.at[i, 'ice_refl_cosine'] = spectral_reflection_cosine(\n", + " data['peaks'], data['ice_peaks'], transform=np.sqrt\n", " )\n", - " df_cas.at[i, \"ice_steins\"] = reweighted_dot(data[\"peaks\"], data[\"ice_peaks\"])" + " df_cas.at[i, 'ice_steins'] = reweighted_dot(data['peaks'], data['ice_peaks'])" ] }, { @@ -16552,18 +16552,18 @@ "import fiora.IO.mspredReader as mspredReader\n", "\n", "iceberg_dir = (\n", - " f\"{home}/repos/ms-pred/results/test_out_recovery/casmi16t/tree_preds_inten\"\n", + " f'{home}/repos/ms-pred/results/test_out_recovery/casmi16t/tree_preds_inten'\n", ")\n", "df_ice = mspredReader.read(iceberg_dir)\n", - "df_ice = df_ice.rename(columns={\"peaks\": \"ice_peaks\", \"name\": \"ice_name\"})\n", + "df_ice = df_ice.rename(columns={'peaks': 'ice_peaks', 'name': 'ice_name'})\n", "\n", "\n", "df_cast = pd.merge(\n", " df_cast,\n", - " df_ice[[\"ice_name\", \"ice_peaks\"]],\n", - " left_on=\"ChallengeName\",\n", - " right_on=\"ice_name\",\n", - " how=\"left\",\n", + " df_ice[['ice_name', 'ice_peaks']],\n", + " left_on='ChallengeName',\n", + " right_on='ice_name',\n", + " how='left',\n", ")" ] }, @@ -16573,26 +16573,26 @@ "metadata": {}, "outputs": [], "source": [ 
- "df_cast[[\"ice_cosine\", \"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_precice_refl_cosine\"]] = (\n", + "df_cast[['ice_cosine', 'ice_sqrt_cosine', 'ice_sqrt_cosine_wo_precice_refl_cosine']] = (\n", " np.nan\n", ")\n", - "for i, data in df_cast[df_cast[\"Precursor_type\"] == \"[M+H]+\"].iterrows():\n", - " df_cast.at[i, \"ice_cosine\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"])\n", - " df_cast.at[i, \"ice_sqrt_cosine\"] = spectral_cosine(\n", - " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + "for i, data in df_cast[df_cast['Precursor_type'] == '[M+H]+'].iterrows():\n", + " df_cast.at[i, 'ice_cosine'] = spectral_cosine(data['peaks'], data['ice_peaks'])\n", + " df_cast.at[i, 'ice_sqrt_cosine'] = spectral_cosine(\n", + " data['peaks'], data['ice_peaks'], transform=np.sqrt\n", " )\n", - " df_cast.at[i, \"ice_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", - " data[\"peaks\"],\n", - " data[\"ice_peaks\"],\n", + " df_cast.at[i, 'ice_sqrt_cosine_wo_prec'] = spectral_cosine(\n", + " data['peaks'],\n", + " data['ice_peaks'],\n", " transform=np.sqrt,\n", - " remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(\n", - " ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " remove_mz=data['Metabolite'].get_theoretical_precursor_mz(\n", + " ion_type=data['Metabolite'].metadata['precursor_mode']\n", " ),\n", " )\n", - " df_cast.at[i, \"ice_refl_cosine\"] = spectral_reflection_cosine(\n", - " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " df_cast.at[i, 'ice_refl_cosine'] = spectral_reflection_cosine(\n", + " data['peaks'], data['ice_peaks'], transform=np.sqrt\n", " )\n", - " df_cast.at[i, \"ice_steins\"] = reweighted_dot(data[\"peaks\"], data[\"ice_peaks\"])" + " df_cast.at[i, 'ice_steins'] = reweighted_dot(data['peaks'], data['ice_peaks'])" ] }, { @@ -16601,17 +16601,17 @@ "metadata": {}, "outputs": [], "source": [ - "iceberg_dir = f\"{home}/repos/ms-pred/results/test_out_recovery/casmi22/tree_preds_inten\"\n", + "iceberg_dir = f'{home}/repos/ms-pred/results/test_out_recovery/casmi22/tree_preds_inten'\n", "df_ice = mspredReader.read(iceberg_dir)\n", - "df_ice = df_ice.rename(columns={\"peaks\": \"ice_peaks\", \"name\": \"ice_name\"})\n", + "df_ice = df_ice.rename(columns={'peaks': 'ice_peaks', 'name': 'ice_name'})\n", "\n", - "df_cas22[\"idx\"] = [f\"spec{i}\" for i, _ in df_cas22.iterrows()]\n", + "df_cas22['idx'] = [f'spec{i}' for i, _ in df_cas22.iterrows()]\n", "df_cas22 = pd.merge(\n", " df_cas22,\n", - " df_ice[[\"ice_name\", \"ice_peaks\"]],\n", - " left_on=\"idx\",\n", - " right_on=\"ice_name\",\n", - " how=\"left\",\n", + " df_ice[['ice_name', 'ice_peaks']],\n", + " left_on='idx',\n", + " right_on='ice_name',\n", + " how='left',\n", ")" ] }, @@ -16622,27 +16622,27 @@ "outputs": [], "source": [ "df_cas22[\n", - " [\"ice_cosine\", \"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\", \"ice_refl_cosine\"]\n", + " ['ice_cosine', 'ice_sqrt_cosine', 'ice_sqrt_cosine_wo_prec', 'ice_refl_cosine']\n", "] = np.nan\n", - "for i, data in df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"].iterrows():\n", + "for i, data in df_cas22[df_cas22['Precursor_type'] == '[M+H]+'].iterrows():\n", " # print(i, data[\"ice_peaks\"], data[\"ice_peaks\"] is not np.nan)\n", - " if data[\"ice_peaks\"] is not np.nan:\n", - " df_cas22.at[i, \"ice_cosine\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"])\n", - " df_cas22.at[i, \"ice_sqrt_cosine\"] = spectral_cosine(\n", - " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " if data['ice_peaks'] is 
not np.nan:\n", + " df_cas22.at[i, 'ice_cosine'] = spectral_cosine(data['peaks'], data['ice_peaks'])\n", + " df_cas22.at[i, 'ice_sqrt_cosine'] = spectral_cosine(\n", + " data['peaks'], data['ice_peaks'], transform=np.sqrt\n", " )\n", - " df_cas22.at[i, \"ice_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", - " data[\"peaks\"],\n", - " data[\"ice_peaks\"],\n", + " df_cas22.at[i, 'ice_sqrt_cosine_wo_prec'] = spectral_cosine(\n", + " data['peaks'],\n", + " data['ice_peaks'],\n", " transform=np.sqrt,\n", - " remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(\n", - " ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " remove_mz=data['Metabolite'].get_theoretical_precursor_mz(\n", + " ion_type=data['Metabolite'].metadata['precursor_mode']\n", " ),\n", " )\n", - " df_cas22.at[i, \"ice_refl_cosine\"] = spectral_reflection_cosine(\n", - " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " df_cas22.at[i, 'ice_refl_cosine'] = spectral_reflection_cosine(\n", + " data['peaks'], data['ice_peaks'], transform=np.sqrt\n", " )\n", - " df_cas22.at[i, \"ice_steins\"] = reweighted_dot(data[\"peaks\"], data[\"ice_peaks\"])" + " df_cas22.at[i, 'ice_steins'] = reweighted_dot(data['peaks'], data['ice_peaks'])" ] }, { @@ -16709,9 +16709,9 @@ } ], "source": [ - "iceberg_dir = f\"{home}/repos/ms-pred/results/test_out_recovery/df_test/tree_preds_inten\"\n", + "iceberg_dir = f'{home}/repos/ms-pred/results/test_out_recovery/df_test/tree_preds_inten'\n", "df_ice = mspredReader.read(iceberg_dir)\n", - "df_ice = df_ice.rename(columns={\"peaks\": \"ice_peaks\", \"name\": \"ice_name\"})\n", + "df_ice = df_ice.rename(columns={'peaks': 'ice_peaks', 'name': 'ice_name'})\n", "# df_ice[\"ice_idx\"] = df_ice[\"ice_name\"].str.extract(r'spec(\\d+)', expand=False).astype(int)\n", "df_ice.head(2)" ] @@ -16738,49 +16738,49 @@ "metadata": {}, "outputs": [], "source": [ - "iceberg_dir = f\"{home}/repos/ms-pred/results/test_out_recovery/df_test/tree_preds_inten\"\n", + "iceberg_dir = f'{home}/repos/ms-pred/results/test_out_recovery/df_test/tree_preds_inten'\n", "df_ice = mspredReader.read(iceberg_dir)\n", - "df_ice = df_ice.rename(columns={\"peaks\": \"ice_peaks\", \"name\": \"ice_name\"})\n", - "\n", - "df_test[\"num\"] = df_test.groupby(\"group_id\").cumcount() + 1\n", - "df_test[\"idx\"] = (\n", - " \"spec\"\n", - " + df_test[\"group_id\"].astype(int).astype(str)\n", - " + \"_\"\n", - " + df_test[\"num\"].astype(str)\n", + "df_ice = df_ice.rename(columns={'peaks': 'ice_peaks', 'name': 'ice_name'})\n", + "\n", + "df_test['num'] = df_test.groupby('group_id').cumcount() + 1\n", + "df_test['idx'] = (\n", + " 'spec'\n", + " + df_test['group_id'].astype(int).astype(str)\n", + " + '_'\n", + " + df_test['num'].astype(str)\n", ") # df_test[\"idx\"] = [f\"spec{i}\" for i,_ in df_test.iterrows()]\n", "ori_idx = df_test.index.copy()\n", "df_test = pd.merge(\n", " df_test,\n", - " df_ice[[\"ice_name\", \"ice_peaks\"]],\n", - " left_on=\"idx\",\n", - " right_on=\"ice_name\",\n", - " how=\"left\",\n", + " df_ice[['ice_name', 'ice_peaks']],\n", + " left_on='idx',\n", + " right_on='ice_name',\n", + " how='left',\n", ")\n", "df_test.index = ori_idx\n", "# df_test.index = df_test[\"idx\"].str.extract(r'spec(\\d+)', expand=False).astype(int) TODO CHECK what happens to the index\n", "\n", "df_test[\n", - " [\"ice_cosine\", \"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\", \"ice_refl_cosine\"]\n", + " ['ice_cosine', 'ice_sqrt_cosine', 'ice_sqrt_cosine_wo_prec', 'ice_refl_cosine']\n", "] = np.nan\n", - "for i, data in 
df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"].iterrows():\n", + "for i, data in df_test[df_test['Precursor_type'] == '[M+H]+'].iterrows():\n", " try:\n", - " df_test.at[i, \"ice_cosine\"] = spectral_cosine(data[\"peaks\"], data[\"ice_peaks\"])\n", - " df_test.at[i, \"ice_sqrt_cosine\"] = spectral_cosine(\n", - " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " df_test.at[i, 'ice_cosine'] = spectral_cosine(data['peaks'], data['ice_peaks'])\n", + " df_test.at[i, 'ice_sqrt_cosine'] = spectral_cosine(\n", + " data['peaks'], data['ice_peaks'], transform=np.sqrt\n", " )\n", - " df_test.at[i, \"ice_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", - " data[\"peaks\"],\n", - " data[\"ice_peaks\"],\n", + " df_test.at[i, 'ice_sqrt_cosine_wo_prec'] = spectral_cosine(\n", + " data['peaks'],\n", + " data['ice_peaks'],\n", " transform=np.sqrt,\n", - " remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(\n", - " ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " remove_mz=data['Metabolite'].get_theoretical_precursor_mz(\n", + " ion_type=data['Metabolite'].metadata['precursor_mode']\n", " ),\n", " )\n", - " df_test.at[i, \"ice_refl_cosine\"] = spectral_reflection_cosine(\n", - " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " df_test.at[i, 'ice_refl_cosine'] = spectral_reflection_cosine(\n", + " data['peaks'], data['ice_peaks'], transform=np.sqrt\n", " )\n", - " df_test.at[i, \"ice_steins\"] = reweighted_dot(data[\"peaks\"], data[\"ice_peaks\"])\n", + " df_test.at[i, 'ice_steins'] = reweighted_dot(data['peaks'], data['ice_peaks'])\n", " except:\n", " pass" ] @@ -16792,53 +16792,53 @@ "outputs": [], "source": [ "iceberg_dir = (\n", - " f\"{home}/repos/ms-pred/results/test_out_recovery/df_msnlib_test/tree_preds_inten\"\n", + " f'{home}/repos/ms-pred/results/test_out_recovery/df_msnlib_test/tree_preds_inten'\n", ")\n", "df_ice = mspredReader.read(iceberg_dir)\n", - "df_ice = df_ice.rename(columns={\"peaks\": \"ice_peaks\", \"name\": \"ice_name\"})\n", - "\n", - "df_msnlib_test[\"num\"] = df_msnlib_test.groupby(\"group_id\").cumcount() + 1\n", - "df_msnlib_test[\"idx\"] = (\n", - " \"spec\"\n", - " + df_msnlib_test[\"group_id\"].astype(int).astype(str)\n", - " + \"_\"\n", - " + df_msnlib_test[\"num\"].astype(str)\n", + "df_ice = df_ice.rename(columns={'peaks': 'ice_peaks', 'name': 'ice_name'})\n", + "\n", + "df_msnlib_test['num'] = df_msnlib_test.groupby('group_id').cumcount() + 1\n", + "df_msnlib_test['idx'] = (\n", + " 'spec'\n", + " + df_msnlib_test['group_id'].astype(int).astype(str)\n", + " + '_'\n", + " + df_msnlib_test['num'].astype(str)\n", ") # df_msnlib_test[\"idx\"] = [f\"spec{i}\" for i,_ in df_msnlib_test.iterrows()]\n", "ori_idx = df_msnlib_test.index.copy()\n", "df_msnlib_test = pd.merge(\n", " df_msnlib_test,\n", - " df_ice[[\"ice_name\", \"ice_peaks\"]],\n", - " left_on=\"idx\",\n", - " right_on=\"ice_name\",\n", - " how=\"left\",\n", + " df_ice[['ice_name', 'ice_peaks']],\n", + " left_on='idx',\n", + " right_on='ice_name',\n", + " how='left',\n", ")\n", "df_msnlib_test.index = ori_idx\n", "# df_msnlib_test.index = df_msnlib_test[\"idx\"].str.extract(r'spec(\\d+)', expand=False).astype(int) TODO CHECK what happens to the index\n", "\n", "df_msnlib_test[\n", - " [\"ice_cosine\", \"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\", \"ice_refl_cosine\"]\n", + " ['ice_cosine', 'ice_sqrt_cosine', 'ice_sqrt_cosine_wo_prec', 'ice_refl_cosine']\n", "] = np.nan\n", - "for i, data in df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == 
\"[M+H]+\"].iterrows():\n", + "for i, data in df_msnlib_test[df_msnlib_test['Precursor_type'] == '[M+H]+'].iterrows():\n", " try:\n", - " df_msnlib_test.at[i, \"ice_cosine\"] = spectral_cosine(\n", - " data[\"peaks\"], data[\"ice_peaks\"]\n", + " df_msnlib_test.at[i, 'ice_cosine'] = spectral_cosine(\n", + " data['peaks'], data['ice_peaks']\n", " )\n", - " df_msnlib_test.at[i, \"ice_sqrt_cosine\"] = spectral_cosine(\n", - " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " df_msnlib_test.at[i, 'ice_sqrt_cosine'] = spectral_cosine(\n", + " data['peaks'], data['ice_peaks'], transform=np.sqrt\n", " )\n", - " df_msnlib_test.at[i, \"ice_sqrt_cosine_wo_prec\"] = spectral_cosine(\n", - " data[\"peaks\"],\n", - " data[\"ice_peaks\"],\n", + " df_msnlib_test.at[i, 'ice_sqrt_cosine_wo_prec'] = spectral_cosine(\n", + " data['peaks'],\n", + " data['ice_peaks'],\n", " transform=np.sqrt,\n", - " remove_mz=data[\"Metabolite\"].get_theoretical_precursor_mz(\n", - " ion_type=data[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " remove_mz=data['Metabolite'].get_theoretical_precursor_mz(\n", + " ion_type=data['Metabolite'].metadata['precursor_mode']\n", " ),\n", " )\n", - " df_msnlib_test.at[i, \"ice_refl_cosine\"] = spectral_reflection_cosine(\n", - " data[\"peaks\"], data[\"ice_peaks\"], transform=np.sqrt\n", + " df_msnlib_test.at[i, 'ice_refl_cosine'] = spectral_reflection_cosine(\n", + " data['peaks'], data['ice_peaks'], transform=np.sqrt\n", " )\n", - " df_msnlib_test.at[i, \"ice_steins\"] = reweighted_dot(\n", - " data[\"peaks\"], data[\"ice_peaks\"]\n", + " df_msnlib_test.at[i, 'ice_steins'] = reweighted_dot(\n", + " data['peaks'], data['ice_peaks']\n", " )\n", " except:\n", " pass" @@ -16879,10 +16879,10 @@ "i = 81\n", "ax = sv.plot_spectrum(\n", " df_cas.iloc[i],\n", - " {\"peaks\": df_cas.iloc[i][\"ice_peaks\"]},\n", - " title=df_cas.iloc[i][\"ChallengeName\"]\n", - " + \" vs ICEBERG pred: \"\n", - " + df_cas.iloc[i][\"ice_name\"],\n", + " {'peaks': df_cas.iloc[i]['ice_peaks']},\n", + " title=df_cas.iloc[i]['ChallengeName']\n", + " + ' vs ICEBERG pred: '\n", + " + df_cas.iloc[i]['ice_name'],\n", " highlight_matches=False,\n", " with_grid=False,\n", " ax=ax,\n", @@ -16936,7 +16936,7 @@ } ], "source": [ - "df_msnlib_test.groupby(\"Precursor_type\")[\"group_id\"].nunique()" + "df_msnlib_test.groupby('Precursor_type')['group_id'].nunique()" ] }, { @@ -17061,98 +17061,98 @@ ], "source": [ "# Default score\n", - "score = \"spectral_sqrt_cosine\"\n", + "score = 'spectral_sqrt_cosine'\n", "avg_func = np.median\n", "\n", "fiora_res = {\n", - " \"model\": \"Fiora\",\n", - " \"Test+\": avg_func(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][score]),\n", - " \"Test-\": avg_func(df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"][score]),\n", - " \"MSnLib+\": avg_func(\n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)\n", + " 'model': 'Fiora',\n", + " 'Test+': avg_func(df_test[df_test['Precursor_type'] == '[M+H]+'][score]),\n", + " 'Test-': avg_func(df_test[df_test['Precursor_type'] == '[M-H]-'][score]),\n", + " 'MSnLib+': avg_func(\n", + " df_msnlib_test[df_msnlib_test['Precursor_type'] == '[M+H]+'][score].fillna(0.0)\n", " ),\n", - " \"MSnLib-\": avg_func(\n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)\n", + " 'MSnLib-': avg_func(\n", + " df_msnlib_test[df_msnlib_test['Precursor_type'] == '[M-H]-'][score].fillna(0.0)\n", " ),\n", - " \"CASMI16+\": avg_func(df_cas[df_cas[\"Precursor_type\"] == 
\"[M+H]+\"][score]),\n", - " \"CASMI16-\": avg_func(df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][score]),\n", - " \"CASMI22+\": avg_func(df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][score]),\n", - " \"CASMI22-\": avg_func(df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][score]),\n", + " 'CASMI16+': avg_func(df_cas[df_cas['Precursor_type'] == '[M+H]+'][score]),\n", + " 'CASMI16-': avg_func(df_cas[df_cas['Precursor_type'] == '[M-H]-'][score]),\n", + " 'CASMI22+': avg_func(df_cas22[df_cas22['Precursor_type'] == '[M+H]+'][score]),\n", + " 'CASMI22-': avg_func(df_cas22[df_cas22['Precursor_type'] == '[M-H]-'][score]),\n", "}\n", "cfm_id = {\n", - " \"model\": \"CFM-ID 4.4.7\",\n", - " \"Test+\": avg_func(\n", - " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'model': 'CFM-ID 4.4.7',\n", + " 'Test+': avg_func(\n", + " df_test[df_test['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'cfm')\n", " ].fillna(0.0)\n", " ),\n", - " \"Test-\": avg_func(\n", - " df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'Test-': avg_func(\n", + " df_test[df_test['Precursor_type'] == '[M-H]-'][\n", + " score.replace('spectral', 'cfm')\n", " ].fillna(0.0)\n", " ),\n", - " \"MSnLib+\": avg_func(\n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'MSnLib+': avg_func(\n", + " df_msnlib_test[df_msnlib_test['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'cfm')\n", " ].fillna(0.0)\n", " ),\n", - " \"MSnLib-\": avg_func(\n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M-H]-\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'MSnLib-': avg_func(\n", + " df_msnlib_test[df_msnlib_test['Precursor_type'] == '[M-H]-'][\n", + " score.replace('spectral', 'cfm')\n", " ].fillna(0.0)\n", " ),\n", - " \"CASMI16+\": avg_func(\n", - " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'CASMI16+': avg_func(\n", + " df_cas[df_cas['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'cfm')\n", " ].fillna(0.0)\n", " ),\n", - " \"CASMI16-\": avg_func(\n", - " df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'CASMI16-': avg_func(\n", + " df_cas[df_cas['Precursor_type'] == '[M-H]-'][\n", + " score.replace('spectral', 'cfm')\n", " ].fillna(0.0)\n", " ),\n", - " \"CASMI22+\": avg_func(\n", - " df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'CASMI22+': avg_func(\n", + " df_cas22[df_cas22['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'cfm')\n", " ]\n", " ),\n", - " \"CASMI22-\": avg_func(\n", - " df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'CASMI22-': avg_func(\n", + " df_cas22[df_cas22['Precursor_type'] == '[M-H]-'][\n", + " score.replace('spectral', 'cfm')\n", " ]\n", " ),\n", "}\n", "ice_res = {\n", - " \"model\": \"ICEBERG\",\n", - " \"Test+\": avg_func(\n", - " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"ice\")\n", + " 'model': 'ICEBERG',\n", + " 'Test+': avg_func(\n", + " df_test[df_test['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'ice')\n", " ].fillna(0.0)\n", " ),\n", - " \"MSnLib+\": avg_func(\n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == 
\"[M+H]+\"][\n", - " score.replace(\"spectral\", \"ice\")\n", + " 'MSnLib+': avg_func(\n", + " df_msnlib_test[df_msnlib_test['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'ice')\n", " ].fillna(0.0)\n", " ),\n", - " \"CASMI16+\": avg_func(\n", - " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"ice\")\n", + " 'CASMI16+': avg_func(\n", + " df_cas[df_cas['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'ice')\n", " ].fillna(0.0)\n", " ),\n", - " \"CASMI22+\": avg_func(\n", - " df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"ice\")\n", + " 'CASMI22+': avg_func(\n", + " df_cas22[df_cas22['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'ice')\n", " ].fillna(0.0)\n", " ),\n", - " \"CASMI22-\": avg_func(\n", - " df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][\n", - " score.replace(\"spectral\", \"ice\")\n", + " 'CASMI22-': avg_func(\n", + " df_cas22[df_cas22['Precursor_type'] == '[M-H]-'][\n", + " score.replace('spectral', 'ice')\n", " ].fillna(0.0)\n", " ),\n", "}\n", "\n", "summaryPos = pd.DataFrame([fiora_res, cfm_id, ice_res])\n", - "print(\"Summary test sets\")\n", + "print('Summary test sets')\n", "summaryPos" ] }, @@ -17267,105 +17267,105 @@ } ], "source": [ - "score = \"spectral_sqrt_cosine_wo_prec\"\n", + "score = 'spectral_sqrt_cosine_wo_prec'\n", "avg_func = np.median\n", "\n", "fiora_res = {\n", - " \"model\": \"Fiora\",\n", - " \"Test+\": avg_func(\n", - " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)\n", + " 'model': 'Fiora',\n", + " 'Test+': avg_func(\n", + " df_test[df_test['Precursor_type'] == '[M+H]+'][score].fillna(0.0)\n", " ),\n", - " \"Test-\": avg_func(\n", - " df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)\n", + " 'Test-': avg_func(\n", + " df_test[df_test['Precursor_type'] == '[M-H]-'][score].fillna(0.0)\n", " ),\n", - " \"MSnLib+\": avg_func(\n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)\n", + " 'MSnLib+': avg_func(\n", + " df_msnlib_test[df_msnlib_test['Precursor_type'] == '[M+H]+'][score].fillna(0.0)\n", " ),\n", - " \"MSnLib-\": avg_func(\n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)\n", + " 'MSnLib-': avg_func(\n", + " df_msnlib_test[df_msnlib_test['Precursor_type'] == '[M-H]-'][score].fillna(0.0)\n", " ),\n", - " \"CASMI16+\": avg_func(\n", - " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)\n", + " 'CASMI16+': avg_func(\n", + " df_cas[df_cas['Precursor_type'] == '[M+H]+'][score].fillna(0.0)\n", " ),\n", - " \"CASMI16-\": avg_func(\n", - " df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)\n", + " 'CASMI16-': avg_func(\n", + " df_cas[df_cas['Precursor_type'] == '[M-H]-'][score].fillna(0.0)\n", " ),\n", - " \"CASMI22+\": avg_func(\n", - " df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)\n", + " 'CASMI22+': avg_func(\n", + " df_cas22[df_cas22['Precursor_type'] == '[M+H]+'][score].fillna(0.0)\n", " ),\n", - " \"CASMI22-\": avg_func(\n", - " df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)\n", + " 'CASMI22-': avg_func(\n", + " df_cas22[df_cas22['Precursor_type'] == '[M-H]-'][score].fillna(0.0)\n", " ),\n", "}\n", "cfm_id = {\n", - " \"model\": \"CFM-ID 4.4.7\",\n", - " \"Test+\": avg_func(\n", - " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'model': 
'CFM-ID 4.4.7',\n", + " 'Test+': avg_func(\n", + " df_test[df_test['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'cfm')\n", " ].fillna(0.0)\n", " ),\n", - " \"Test-\": avg_func(\n", - " df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'Test-': avg_func(\n", + " df_test[df_test['Precursor_type'] == '[M-H]-'][\n", + " score.replace('spectral', 'cfm')\n", " ].fillna(0.0)\n", " ),\n", - " \"MSnLib+\": avg_func(\n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'MSnLib+': avg_func(\n", + " df_msnlib_test[df_msnlib_test['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'cfm')\n", " ].fillna(0.0)\n", " ),\n", - " \"MSnLib-\": avg_func(\n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M-H]-\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'MSnLib-': avg_func(\n", + " df_msnlib_test[df_msnlib_test['Precursor_type'] == '[M-H]-'][\n", + " score.replace('spectral', 'cfm')\n", " ].fillna(0.0)\n", " ),\n", - " \"CASMI16+\": avg_func(\n", - " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'CASMI16+': avg_func(\n", + " df_cas[df_cas['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'cfm')\n", " ].fillna(0.0)\n", " ),\n", - " \"CASMI16-\": avg_func(\n", - " df_cas[df_cas[\"Precursor_type\"] == \"[M-H]-\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'CASMI16-': avg_func(\n", + " df_cas[df_cas['Precursor_type'] == '[M-H]-'][\n", + " score.replace('spectral', 'cfm')\n", " ].fillna(0.0)\n", " ),\n", - " \"CASMI22+\": avg_func(\n", - " df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'CASMI22+': avg_func(\n", + " df_cas22[df_cas22['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'cfm')\n", " ]\n", " ),\n", - " \"CASMI22-\": avg_func(\n", - " df_cas22[df_cas22[\"Precursor_type\"] == \"[M-H]-\"][\n", - " score.replace(\"spectral\", \"cfm\")\n", + " 'CASMI22-': avg_func(\n", + " df_cas22[df_cas22['Precursor_type'] == '[M-H]-'][\n", + " score.replace('spectral', 'cfm')\n", " ]\n", " ),\n", "}\n", "ice_res = {\n", - " \"model\": \"ICEBERG\",\n", - " \"Test+\": avg_func(\n", - " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"ice\")\n", + " 'model': 'ICEBERG',\n", + " 'Test+': avg_func(\n", + " df_test[df_test['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'ice')\n", " ].fillna(0.0)\n", " ),\n", - " \"MSnLib+\": avg_func(\n", - " df_msnlib_test[df_msnlib_test[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"ice\")\n", + " 'MSnLib+': avg_func(\n", + " df_msnlib_test[df_msnlib_test['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'ice')\n", " ].fillna(0.0)\n", " ),\n", - " \"CASMI16+\": avg_func(\n", - " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"ice\")\n", + " 'CASMI16+': avg_func(\n", + " df_cas[df_cas['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'ice')\n", " ].fillna(0.0)\n", " ),\n", - " \"CASMI22+\": avg_func(\n", - " df_cas22[df_cas22[\"Precursor_type\"] == \"[M+H]+\"][\n", - " score.replace(\"spectral\", \"ice\")\n", + " 'CASMI22+': avg_func(\n", + " df_cas22[df_cas22['Precursor_type'] == '[M+H]+'][\n", + " score.replace('spectral', 'ice')\n", " ].fillna(0.0)\n", " ),\n", "}\n", "\n", "summaryPos = 
pd.DataFrame([fiora_res, cfm_id, ice_res])\n", - "print(\"Summary test sets - without precursor\")\n", + "print('Summary test sets - without precursor')\n", "summaryPos" ] }, @@ -17375,7 +17375,7 @@ "metadata": {}, "outputs": [], "source": [ - "model.model_params[\"layer_norm\"] = True" + "model.model_params['layer_norm'] = True" ] }, { @@ -17384,7 +17384,7 @@ "metadata": {}, "outputs": [], "source": [ - "model.save(\"../resources/models/fiora_OS_v1.0.0.pt\")" + "model.save('../resources/models/fiora_OS_v1.0.0.pt')" ] }, { @@ -17414,17 +17414,17 @@ " split_text_experimental=False,\n", " verbose: bool = False,\n", "):\n", - " name_tags = [\"Fiora\", \"ICEBERG\", \"CFM-ID\"]\n", - " peak_tags = [\"sim_peaks\", \"ice_peaks\", \"cfm_peaks\"]\n", + " name_tags = ['Fiora', 'ICEBERG', 'CFM-ID']\n", + " peak_tags = ['sim_peaks', 'ice_peaks', 'cfm_peaks']\n", " scores = [\n", - " (\"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\"),\n", - " (\"ice_sqrt_cosine\", \"ice_sqrt_cosine_wo_prec\"),\n", - " (\"cfm_sqrt_cosine\", \"cfm_sqrt_cosine_wo_prec\"),\n", + " ('spectral_sqrt_cosine', 'spectral_sqrt_cosine_wo_prec'),\n", + " ('ice_sqrt_cosine', 'ice_sqrt_cosine_wo_prec'),\n", + " ('cfm_sqrt_cosine', 'cfm_sqrt_cosine_wo_prec'),\n", " ]\n", " peak_colors = [\n", - " \"#0080FF\",\n", - " \"#FF3333\",\n", - " \"#FFCC00\",\n", + " '#0080FF',\n", + " '#FF3333',\n", + " '#FFCC00',\n", " ] # sns.color_palette(\"YlOrBr\", 10)[3]]\n", " spec_height, spec_width = 1.5, 8\n", "\n", @@ -17434,16 +17434,16 @@ " figsize=(spec_width, spec_height * (len(peak_tags) + 2)),\n", " sharex=True,\n", " ) # gridspec_kw={'width_ratios': [1, 3]}\n", - " img = data[\"Metabolite\"].draw(ax=axs[0])\n", - " textstr = f\"Name: {data['Name'] if 'Name' in data.keys() else data['NAME'] if 'NAME' in data.keys() else data['ChallengeName']}\\nPrecursor type: {data['Precursor_type']}\\nCollision energy: {data['CE']:.1f} eV\"\n", + " img = data['Metabolite'].draw(ax=axs[0])\n", + " textstr = f'Name: {data[\"Name\"] if \"Name\" in data.keys() else data[\"NAME\"] if \"NAME\" in data.keys() else data[\"ChallengeName\"]}\\nPrecursor type: {data[\"Precursor_type\"]}\\nCollision energy: {data[\"CE\"]:.1f} eV'\n", " fig.text(\n", " 0.54 + text_offset,\n", " 0.8,\n", " textstr, # transform=axs[0].transAxes,\n", " fontsize=8,\n", - " horizontalalignment=\"left\",\n", - " verticalalignment=\"top\",\n", - " bbox=dict(facecolor=\"white\", alpha=0),\n", + " horizontalalignment='left',\n", + " verticalalignment='top',\n", + " bbox=dict(facecolor='white', alpha=0),\n", " )\n", "\n", " plt.subplots_adjust(hspace=0.12) # (top=0.94, bottom=0.12, right=0.97, left=0.08)\n", @@ -17452,44 +17452,44 @@ " axs[1].text(\n", " spec_text_position_override[0],\n", " spec_text_position_override[1],\n", - " \"Experimental\\nspectrum\"\n", + " 'Experimental\\nspectrum'\n", " if split_text_experimental\n", - " else \"Experimental spectrum\",\n", + " else 'Experimental spectrum',\n", " transform=axs[1].transAxes,\n", " fontsize=11,\n", - " verticalalignment=\"top\",\n", + " verticalalignment='top',\n", " )\n", - " axs[1].set_xlabel(\"\")\n", + " axs[1].set_xlabel('')\n", " for i, tag in enumerate(peak_tags):\n", " try:\n", " ax = axs[i + 2]\n", - " sv.plot_spectrum({\"peaks\": data[tag]}, ax=ax, color=peak_colors[i])\n", + " sv.plot_spectrum({'peaks': data[tag]}, ax=ax, color=peak_colors[i])\n", " # ax.legend(title=name_tags[i], loc=\"upper left\", labels=scores[i])\n", - " textstr = \"\\n\".join(\n", + " textstr = '\\n'.join(\n", " (\n", - " 
f\"$\\\\bf{{{name_tags[i]}}}$\",\n", - " f\"Cosine: {data[scores[i][0]]:.2f}\",\n", - " f\"w/o prec: {data[scores[i][1]]:.2f}\",\n", + " f'$\\\\bf{{{name_tags[i]}}}$',\n", + " f'Cosine: {data[scores[i][0]]:.2f}',\n", + " f'w/o prec: {data[scores[i][1]]:.2f}',\n", " )\n", " )\n", "\n", - " ax.set_xlabel(\"\")\n", + " ax.set_xlabel('')\n", " ax.text(\n", " spec_text_position_override[0],\n", " spec_text_position_override[1],\n", " textstr,\n", " transform=ax.transAxes,\n", " fontsize=11,\n", - " verticalalignment=\"top\",\n", + " verticalalignment='top',\n", " ) # ,\n", " # bbox=dict(boxstyle='square,pad=0.5', facecolor='white', alpha=0.5))\n", " except:\n", " if verbose:\n", - " print(f\"Could not plot spectrum {data[tag]}\")\n", + " print(f'Could not plot spectrum {data[tag]}')\n", " continue\n", - " axs[-1].set_xlabel(\"m/z\")\n", + " axs[-1].set_xlabel('m/z')\n", "\n", - " sv.set_default_peak_color(\"#212121\")\n", + " sv.set_default_peak_color('#212121')\n", "\n", " return fig, axs" ] @@ -17530,11 +17530,11 @@ "outputs": [], "source": [ "def get_average_spectra(\n", - " df: pd.DataFrame, score: str = \"spectral_sqrt_cosine\", dif: float = 0.05\n", + " df: pd.DataFrame, score: str = 'spectral_sqrt_cosine', dif: float = 0.05\n", "):\n", - " median_cos, median_cos_wo_prec = df[score].median(), df[score + \"_wo_prec\"].median()\n", + " median_cos, median_cos_wo_prec = df[score].median(), df[score + '_wo_prec'].median()\n", " filter_average = (abs(df[score] - median_cos) < dif) & (\n", - " abs(df[score + \"_wo_prec\"] - median_cos_wo_prec) < dif\n", + " abs(df[score + '_wo_prec'] - median_cos_wo_prec) < dif\n", " )\n", "\n", " return df[filter_average]" @@ -17546,12 +17546,12 @@ "metadata": {}, "outputs": [], "source": [ - "df_a = get_average_spectra(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"], dif=0.05)\n", - "df_a = get_average_spectra(df_a, score=\"ice_sqrt_cosine\", dif=0.05)\n", - "df_a = get_average_spectra(df_a, score=\"cfm_sqrt_cosine\", dif=0.05)\n", + "df_a = get_average_spectra(df_test[df_test['Precursor_type'] == '[M+H]+'], dif=0.05)\n", + "df_a = get_average_spectra(df_a, score='ice_sqrt_cosine', dif=0.05)\n", + "df_a = get_average_spectra(df_a, score='cfm_sqrt_cosine', dif=0.05)\n", "df_a = df_a[\n", - " (df_a[\"spectral_sqrt_cosine_wo_prec\"] > df_a[\"ice_sqrt_cosine_wo_prec\"])\n", - " & (df_a[\"spectral_sqrt_cosine_wo_prec\"] > df_a[\"cfm_sqrt_cosine_wo_prec\"])\n", + " (df_a['spectral_sqrt_cosine_wo_prec'] > df_a['ice_sqrt_cosine_wo_prec'])\n", + " & (df_a['spectral_sqrt_cosine_wo_prec'] > df_a['cfm_sqrt_cosine_wo_prec'])\n", "]" ] }, @@ -17617,7 +17617,7 @@ " fig, axs = stacked_spectrum(\n", " df_a.loc[i], text_offset=0.02, split_text_experimental=True, verbose=True\n", " )\n", - " mol = df_a.loc[i][\"Metabolite\"].MOL\n", + " mol = df_a.loc[i]['Metabolite'].MOL\n", " # fig.savefig(f\"{home}/images/paper/stacked_avg{i}.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/stacked_avg{i}.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/stacked_avg{i}.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -17674,20 +17674,20 @@ "high_score_threshold = 0.7\n", "low_score_threshold = 0.5\n", "\n", - "df_ex = df_msnlib_test[df_msnlib_test[\"SPECTYPE\"] == \"SAME_ENERGY\"]\n", + "df_ex = df_msnlib_test[df_msnlib_test['SPECTYPE'] == 'SAME_ENERGY']\n", "\n", "\n", "high_score_30 = df_ex[\n", - " (df_ex[\"CE\"] == 30) & 
(df_ex[\"spectral_sqrt_cosine\"] >= high_score_threshold)\n", + " (df_ex['CE'] == 30) & (df_ex['spectral_sqrt_cosine'] >= high_score_threshold)\n", "]\n", "low_score_60 = df_ex[\n", - " (df_ex[\"CE\"] == 60) & (df_ex[\"spectral_sqrt_cosine\"] <= low_score_threshold)\n", + " (df_ex['CE'] == 60) & (df_ex['spectral_sqrt_cosine'] <= low_score_threshold)\n", "]\n", "\n", "common_group_ids = pd.merge(\n", - " high_score_30[[\"group_id\"]], low_score_60[[\"group_id\"]], on=\"group_id\"\n", + " high_score_30[['group_id']], low_score_60[['group_id']], on='group_id'\n", ")\n", - "matching_ids = common_group_ids[\"group_id\"].unique()" + "matching_ids = common_group_ids['group_id'].unique()" ] }, { @@ -17716,7 +17716,7 @@ "metadata": {}, "outputs": [], "source": [ - "examples = df_ex[df_ex[\"group_id\"] == matching_ids[1]]" + "examples = df_ex[df_ex['group_id'] == matching_ids[1]]" ] }, { @@ -17762,21 +17762,21 @@ " fig, axs = stacked_spectrum(\n", " examples.loc[i], text_offset=0.02, split_text_experimental=True, verbose=True\n", " )\n", - " ce = data[\"CE\"]\n", - " mol = data[\"Metabolite\"].MOL\n", - " ce = data[\"CE\"]\n", + " ce = data['CE']\n", + " mol = data['Metabolite'].MOL\n", + " ce = data['CE']\n", " spec_df = {\n", - " f\"Experimental m/z\": data[\"peaks\"][\"mz\"],\n", - " f\"Experimental intensity\": data[\"peaks\"][\"intensity\"],\n", - " f\"Fiora m/z\": data[\"sim_peaks\"][\"mz\"],\n", - " f\"Fiora intensity\": data[\"sim_peaks\"][\"intensity\"],\n", - " f\"ICEBERG m/z\": data[\"ice_peaks\"][\"mz\"],\n", - " f\"ICEBERG intensity\": data[\"ice_peaks\"][\"intensity\"],\n", - " f\"CFM-ID m/z\": data[\"cfm_peaks\"][\"mz\"],\n", - " f\"CFM-ID intensity\": data[\"cfm_peaks\"][\"intensity\"],\n", + " 'Experimental m/z': data['peaks']['mz'],\n", + " 'Experimental intensity': data['peaks']['intensity'],\n", + " 'Fiora m/z': data['sim_peaks']['mz'],\n", + " 'Fiora intensity': data['sim_peaks']['intensity'],\n", + " 'ICEBERG m/z': data['ice_peaks']['mz'],\n", + " 'ICEBERG intensity': data['ice_peaks']['intensity'],\n", + " 'CFM-ID m/z': data['cfm_peaks']['mz'],\n", + " 'CFM-ID intensity': data['cfm_peaks']['intensity'],\n", " }\n", - " spec_df = pd.DataFrame.from_dict(spec_df, orient=\"index\").transpose()\n", - " spec_df = spec_df.fillna(\"-\")\n", + " spec_df = pd.DataFrame.from_dict(spec_df, orient='index').transpose()\n", + " spec_df = spec_df.fillna('-')\n", " # spec_df.to_excel(f\"{home}/images/paper/SF_{ce}.xlsx\")\n", " # fig.savefig(f\"{home}/images/paper/stacked_ce{ce}.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/stacked_ce{ce}.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -17817,16 +17817,16 @@ ], "source": [ "if has_m_plus:\n", - " print(\"Conducting performance analysis for [M]+ and [M]- precursors\")\n", + " print('Conducting performance analysis for [M]+ and [M]- precursors')\n", "\n", - " df_m = df_msnlib_test[df_msnlib_test[\"Precursor_type\"].isin([\"[M]+\", \"[M]-\"])]\n", - " df_m = df_m[df_m[\"SPECTYPE\"] != \"ALL_MSN_TO_PSEUDO_MS2\"]\n", + " df_m = df_msnlib_test[df_msnlib_test['Precursor_type'].isin(['[M]+', '[M]-'])]\n", + " df_m = df_m[df_m['SPECTYPE'] != 'ALL_MSN_TO_PSEUDO_MS2']\n", " df_not_m = df_msnlib_test[\n", - " df_msnlib_test[\"Precursor_type\"].isin([\"[M+H]+\", \"[M-H]-\"])\n", + " df_msnlib_test['Precursor_type'].isin(['[M+H]+', '[M-H]-'])\n", " ]\n", "\n", - " print(df_m.groupby(\"Precursor_type\")[\"spectral_sqrt_cosine\"].median())\n", - " 
print(df_m.groupby(\"Precursor_type\")[\"spectral_sqrt_cosine_wo_prec\"].median())" + " print(df_m.groupby('Precursor_type')['spectral_sqrt_cosine'].median())\n", + " print(df_m.groupby('Precursor_type')['spectral_sqrt_cosine_wo_prec'].median())" ] }, { @@ -17847,7 +17847,7 @@ ], "source": [ "if has_m_plus:\n", - " print(df_m.groupby(\"Precursor_type\")[\"group_id\"].nunique())" + " print(df_m.groupby('Precursor_type')['group_id'].nunique())" ] }, { @@ -17869,7 +17869,7 @@ } ], "source": [ - "print(\"Training:\", df_msnlib_train.groupby(\"Precursor_type\")[\"group_id\"].nunique())" + "print('Training:', df_msnlib_train.groupby('Precursor_type')['group_id'].nunique())" ] }, { @@ -17896,17 +17896,17 @@ "if has_m_plus:\n", " iii = []\n", " for i, data in df_not_m.iterrows():\n", - " metabolite = data[\"Metabolite\"]\n", + " metabolite = data['Metabolite']\n", " for i, data2 in df_m.iterrows():\n", - " other_metabolite = data2[\"Metabolite\"]\n", + " other_metabolite = data2['Metabolite']\n", " if metabolite == other_metabolite:\n", - " if data[\"CE\"] == data2[\"CE\"] and (\n", - " (\"+\" in data[\"Precursor_type\"]) == (\"+\" in data2[\"Precursor_type\"])\n", + " if data['CE'] == data2['CE'] and (\n", + " ('+' in data['Precursor_type']) == ('+' in data2['Precursor_type'])\n", " ):\n", " print(\n", - " \"cos:\",\n", - " data[\"spectral_sqrt_cosine\"],\n", - " data2[\"spectral_sqrt_cosine\"],\n", + " 'cos:',\n", + " data['spectral_sqrt_cosine'],\n", + " data2['spectral_sqrt_cosine'],\n", " )" ] }, @@ -17959,31 +17959,31 @@ " data = df_m.iloc[0]\n", "\n", " for i, data in df_m.tail(2).iterrows():\n", - " cosine = data[\"spectral_sqrt_cosine\"]\n", - " cosine_wo = data[\"spectral_sqrt_cosine_wo_prec\"]\n", - " name = data[\"NAME\"]\n", - " prec = data[\"Precursor_type\"]\n", + " cosine = data['spectral_sqrt_cosine']\n", + " cosine_wo = data['spectral_sqrt_cosine_wo_prec']\n", + " name = data['NAME']\n", + " prec = data['Precursor_type']\n", "\n", - " print(f\"{prec} ({i}): cosine {cosine:0.2} / {cosine_wo:0.2}\")\n", - " print(max(data[\"peaks\"][\"mz\"]))\n", - " print(max(data[\"sim_peaks\"][\"mz\"]))\n", + " print(f'{prec} ({i}): cosine {cosine:0.2} / {cosine_wo:0.2}')\n", + " print(max(data['peaks']['mz']))\n", + " print(max(data['sim_peaks']['mz']))\n", "\n", " fig, axs = plt.subplots(\n", " 1,\n", " 2,\n", " figsize=(12.8, 4.2),\n", - " gridspec_kw={\"width_ratios\": [1, 3]},\n", + " gridspec_kw={'width_ratios': [1, 3]},\n", " sharey=False,\n", " )\n", - " img = data[\"Metabolite\"].draw(ax=axs[0])\n", + " img = data['Metabolite'].draw(ax=axs[0])\n", "\n", " axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", " )\n", - " axs[0].set_title(data[\"NAME\"])\n", + " axs[0].set_title(data['NAME'])\n", "\n", " ax = sv.plot_spectrum(\n", - " data, {\"peaks\": data[\"sim_peaks\"]}, ax=axs[1], highlight_matches=False\n", + " data, {'peaks': data['sim_peaks']}, ax=axs[1], highlight_matches=False\n", " )\n", " # fig.savefig(f\"{home}/images/paper/example_mirror_id{i}.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/example_mirror_id{i}.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -18008,7 +18008,7 @@ } ], "source": [ - "sns.boxplot(df_msnlib_test, y=\"spectral_sqrt_cosine_wo_prec\", x=\"origin\")\n", + "sns.boxplot(df_msnlib_test, y='spectral_sqrt_cosine_wo_prec', 
x='origin')\n", "plt.show()" ] }, @@ -18030,7 +18030,7 @@ } ], "source": [ - "raise KeyboardInterrupt(\"Stop\")" + "raise KeyboardInterrupt('Stop')" ] }, { @@ -18159,7 +18159,7 @@ } ], "source": [ - "df_test.groupby(\"Precursor_type\").agg(num=(\"group_id\", lambda x: len(x.unique())))" + "df_test.groupby('Precursor_type').agg(num=('group_id', lambda x: len(x.unique())))" ] }, { @@ -18236,8 +18236,8 @@ "metadata": {}, "outputs": [], "source": [ - "import requests\n", "import pubchempy\n", + "import requests\n", "\n", "\n", "def retrieve_first_k_compounds_with_mass(\n", @@ -18246,27 +18246,27 @@ " # Construct the URL for the molecular weight range\n", " lower_bound = mol_weight - tolerance\n", " upper_bound = mol_weight + tolerance\n", - " url = f\"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/exact_mass/range/{lower_bound}/{upper_bound}/cids/JSON\"\n", + " url = f'https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/exact_mass/range/{lower_bound}/{upper_bound}/cids/JSON'\n", " # TODO test molecular_weight\n", "\n", " # Request compound IDs from PubChem\n", " response = requests.get(url)\n", " if response.ok:\n", " js = response.json()\n", - " if \"IdentifierList\" in js and \"CID\" in js[\"IdentifierList\"]:\n", + " if 'IdentifierList' in js and 'CID' in js['IdentifierList']:\n", " # if too few entries -> double tolerance\n", - " if len(js[\"IdentifierList\"][\"CID\"]) < k:\n", + " if len(js['IdentifierList']['CID']) < k:\n", " return retrieve_first_k_compounds_with_mass(\n", " mol_weight, tolerance * 2, k\n", " )\n", " # Retrieve the first k compounds using PubChemPy\n", - " compound_list = pubchempy.get_compounds(js[\"IdentifierList\"][\"CID\"][:k])\n", + " compound_list = pubchempy.get_compounds(js['IdentifierList']['CID'][:k])\n", " return [c.canonical_smiles for c in compound_list]\n", " else:\n", - " print(\"No compounds found in the given range.\")\n", + " print('No compounds found in the given range.')\n", " return retrieve_first_k_compounds_with_mass(mol_weight, tolerance * 2, k)\n", " else:\n", - " print(f\"Error: {response.status_code} - {response.text}\")\n", + " print(f'Error: {response.status_code} - {response.text}')\n", " return []" ] }, @@ -18276,7 +18276,7 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import List, Dict\n", + "from typing import Dict, List\n", "\n", "\n", "# Handle invalid SMILES during Metabolite creation\n", @@ -18294,84 +18294,84 @@ " ce_steps: List[int] = [],\n", " k: int = 10,\n", ") -> Dict[str, float]:\n", - " candidates_df = pd.DataFrame({\"SMILES\": [c for c in candidates if \".\" not in c]})\n", - " candidates_df[\"peaks\"] = [exp_peaks for _ in range(candidates_df.shape[0])]\n", + " candidates_df = pd.DataFrame({'SMILES': [c for c in candidates if '.' 
not in c]})\n", + " candidates_df['peaks'] = [exp_peaks for _ in range(candidates_df.shape[0])]\n", "\n", " # Perform Metabolite fragmentation workflow\n", - " candidates_df[\"Metabolite\"] = candidates_df[\"SMILES\"].apply(\n", + " candidates_df['Metabolite'] = candidates_df['SMILES'].apply(\n", " safe_metabolite_creation\n", " )\n", - " candidates_df.dropna(subset=[\"Metabolite\"], inplace=True)\n", - " eq_metabolite_mask = candidates_df[\"Metabolite\"].apply(lambda m: m == metabolite)\n", + " candidates_df.dropna(subset=['Metabolite'], inplace=True)\n", + " eq_metabolite_mask = candidates_df['Metabolite'].apply(lambda m: m == metabolite)\n", " candidates_df = candidates_df[~eq_metabolite_mask]\n", - " candidates_df[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", - " candidates_df[\"Metabolite\"].apply(\n", + " candidates_df['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", + " candidates_df['Metabolite'].apply(\n", " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", " )\n", " metadata = metabolite.metadata\n", "\n", " if not ce_steps:\n", - " candidates_df[\"Metabolite\"].apply(\n", + " candidates_df['Metabolite'].apply(\n", " lambda x: x.add_metadata(metadata, covariate_encoder, rt_encoder)\n", " )\n", - " candidates_df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + " candidates_df['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", " candidates_df = fiora.simulate_all(candidates_df, model)\n", " else:\n", " for i, ce in enumerate(ce_steps):\n", - " metadata.update({\"collision_energy\": ce})\n", - " candidates_df[\"Metabolite\"].apply(\n", + " metadata.update({'collision_energy': ce})\n", + " candidates_df['Metabolite'].apply(\n", " lambda x: x.add_metadata(metadata, covariate_encoder, rt_encoder)\n", " )\n", - " candidates_df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", - " candidates_df = fiora.simulate_all(candidates_df, model, suffix=f\"_{i + 1}\")\n", + " candidates_df['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", + " candidates_df = fiora.simulate_all(candidates_df, model, suffix=f'_{i + 1}')\n", " if len(ce_steps) != 3:\n", " raise NotImplementedError(\n", - " \"Only three collision energy steps are implemented\"\n", + " 'Only three collision energy steps are implemented'\n", " )\n", - " candidates_df[\"merged_peaks\"] = candidates_df.apply(\n", + " candidates_df['merged_peaks'] = candidates_df.apply(\n", " lambda x: merge_annotated_spectrum(\n", - " merge_annotated_spectrum(x[\"sim_peaks_1\"], x[\"sim_peaks_2\"]),\n", - " x[\"sim_peaks_3\"],\n", + " merge_annotated_spectrum(x['sim_peaks_1'], x['sim_peaks_2']),\n", + " x['sim_peaks_3'],\n", " ),\n", " axis=1,\n", " )\n", - " candidates_df[\"spectral_sqrt_cosine\"] = candidates_df.apply(\n", - " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt),\n", + " candidates_df['spectral_sqrt_cosine'] = candidates_df.apply(\n", + " lambda x: spectral_cosine(x['peaks'], x['merged_peaks'], transform=np.sqrt),\n", " axis=1,\n", " )\n", - " candidates_df[\"spectral_sqrt_cosine_wo_prec\"] = candidates_df.apply(\n", + " candidates_df['spectral_sqrt_cosine_wo_prec'] = candidates_df.apply(\n", " lambda x: spectral_cosine(\n", - " x[\"peaks\"],\n", - " x[\"merged_peaks\"],\n", + " x['peaks'],\n", + " x['merged_peaks'],\n", " transform=np.sqrt,\n", - " remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(\n", - " x[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " 
remove_mz=x['Metabolite'].get_theoretical_precursor_mz(\n", + " x['Metabolite'].metadata['precursor_mode']\n", " ),\n", " ),\n", " axis=1,\n", " )\n", - " candidates_df[\"spectral_sqrt_cosine_avg\"] = (\n", - " candidates_df[\"spectral_sqrt_cosine\"]\n", - " + candidates_df[\"spectral_sqrt_cosine_wo_prec\"]\n", + " candidates_df['spectral_sqrt_cosine_avg'] = (\n", + " candidates_df['spectral_sqrt_cosine']\n", + " + candidates_df['spectral_sqrt_cosine_wo_prec']\n", " ) / 2.0\n", "\n", " score_tags = [\n", - " \"spectral_sqrt_cosine\",\n", - " \"spectral_sqrt_cosine_wo_prec\",\n", - " \"spectral_sqrt_cosine_avg\",\n", + " 'spectral_sqrt_cosine',\n", + " 'spectral_sqrt_cosine_wo_prec',\n", + " 'spectral_sqrt_cosine_avg',\n", " ]\n", " scores = {\n", " tag: candidates_df[tag].sort_values(ascending=False).head(k).values\n", " for tag in score_tags\n", " }\n", - " scores[\"eq_c\"] = list(np.where(eq_metabolite_mask)[0])\n", + " scores['eq_c'] = list(np.where(eq_metabolite_mask)[0])\n", "\n", " return scores\n", "\n", "\n", - "def get_k(row, scoring_func: str = \"spectral_sqrt_cosine\"):\n", + "def get_k(row, scoring_func: str = 'spectral_sqrt_cosine'):\n", " score = row[scoring_func]\n", - " candidate_scores = row[\"candidate_scores\"][scoring_func]\n", + " candidate_scores = row['candidate_scores'][scoring_func]\n", " epsilon = 0.0 # optional: add epsilon = 0.00001 # for Indistinguishable compounds / rounding deviation\n", " return sum([(score - c_score) <= epsilon for c_score in candidate_scores]) + 1" ] @@ -18418,9 +18418,9 @@ "import ast\n", "\n", "# # Load back the candidates into df_test\n", - "df_test[\"candidates\"] = pd.read_csv(\n", - " f\"{home}/data/metabolites/benchmarking/df_test_candidates.csv\", index_col=0\n", - ")[\"candidates\"].apply(ast.literal_eval)" + "df_test['candidates'] = pd.read_csv(\n", + " f'{home}/data/metabolites/benchmarking/df_test_candidates.csv', index_col=0\n", + ")['candidates'].apply(ast.literal_eval)" ] }, { @@ -18429,8 +18429,8 @@ "metadata": {}, "outputs": [], "source": [ - "df_test[\"candidate_scores\"] = df_test.apply(\n", - " lambda x: get_top_k_scores(x[\"candidates\"], x[\"Metabolite\"], x[\"peaks\"], k=10),\n", + "df_test['candidate_scores'] = df_test.apply(\n", + " lambda x: get_top_k_scores(x['candidates'], x['Metabolite'], x['peaks'], k=10),\n", " axis=1,\n", ")" ] @@ -18441,12 +18441,12 @@ "metadata": {}, "outputs": [], "source": [ - "df_test[\"k\"] = df_test.apply(get_k, axis=1)\n", - "df_test[\"k_wo_prec\"] = df_test.apply(\n", - " lambda x: get_k(x, scoring_func=\"spectral_sqrt_cosine_wo_prec\"), axis=1\n", + "df_test['k'] = df_test.apply(get_k, axis=1)\n", + "df_test['k_wo_prec'] = df_test.apply(\n", + " lambda x: get_k(x, scoring_func='spectral_sqrt_cosine_wo_prec'), axis=1\n", ")\n", - "df_test[\"k_avg\"] = df_test.apply(\n", - " lambda x: get_k(x, scoring_func=\"spectral_sqrt_cosine_avg\"), axis=1\n", + "df_test['k_avg'] = df_test.apply(\n", + " lambda x: get_k(x, scoring_func='spectral_sqrt_cosine_avg'), axis=1\n", ")" ] }, @@ -18467,7 +18467,7 @@ } ], "source": [ - "sns.histplot(df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"], x=\"k\", binwidth=1)\n", + "sns.histplot(df_test[df_test['Precursor_type'] == '[M+H]+'], x='k', binwidth=1)\n", "plt.show()" ] }, @@ -18477,8 +18477,8 @@ "metadata": {}, "outputs": [], "source": [ - "import pandas as pd\n", "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", "import seaborn as sns\n", "\n", "\n", @@ -18490,7 +18490,7 @@ " max_rank: int = 11,\n", " ylim: (float, float) = (0, 1),\n", 
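+    " # ylim bounds the recall axis and ratio is the figure size in inches;\n",
+    " # both are forwarded to matplotlib below via plt.ylim(ylim) and plt.subplots(figsize=ratio)\n",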
" ratio=(8, 6),\n", - " title=\"\",\n", + " title='',\n", "):\n", " fig, ax = plt.subplots(figsize=ratio)\n", " # fig.set_size_inches(5, 5)\n", @@ -18512,33 +18512,33 @@ "\n", " # Create a DataFrame for plotting\n", " cumulative_df = pd.DataFrame(\n", - " {\"Rank (k)\": range(1, max_rank), \"Fraction\": fractions.values[:-1]}\n", + " {'Rank (k)': range(1, max_rank), 'Fraction': fractions.values[:-1]}\n", " )\n", "\n", " # Plot\n", " sns.pointplot(\n", " data=cumulative_df,\n", - " x=\"Rank (k)\",\n", - " y=\"Fraction\",\n", + " x='Rank (k)',\n", + " y='Fraction',\n", " label=labels[i] if len(labels) > 0 else k_tag,\n", " linestyle=(0, (1, 2.5))\n", " if (i + 1) % 3 == 0\n", - " else \"-\", # Alternate linestyles\n", - " markers=\"x\" if (i + 1) % 3 == 0 else \"o\",\n", - " color=f\"C{i}\"\n", + " else '-', # Alternate linestyles\n", + " markers='x' if (i + 1) % 3 == 0 else 'o',\n", + " color=f'C{i}'\n", " if len(colors) == 0\n", " else colors[i], # Use different colors for each tag\n", " linewidth=2.5,\n", " )\n", "\n", - " plt.rc(\"axes\", labelsize=14)\n", - " plt.rc(\"legend\", fontsize=14)\n", - " ax.tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", + " plt.rc('axes', labelsize=14)\n", + " plt.rc('legend', fontsize=14)\n", + " ax.tick_params(axis='both', which='major', labelsize=13)\n", "\n", " # ax.set_aspect(1, adjustable=\"box\")\n", " # Final plot settings\n", - " plt.xlabel(\"Rank (k)\")\n", - " plt.ylabel(\"Recall\")\n", + " plt.xlabel('Rank (k)')\n", + " plt.ylabel('Recall')\n", " plt.ylim(ylim)\n", " plt.legend(title=title) # , labels=labels if len(labels) > 0 else k_tags)\n", " plt.grid(True)\n", @@ -18567,14 +18567,14 @@ "\n", "set_light_theme()\n", "fig, fig_data = plot_top_k_performance(\n", - " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"],\n", - " [\"k\", \"k_wo_prec\", \"k_avg\"],\n", + " df_test[df_test['Precursor_type'] == '[M+H]+'],\n", + " ['k', 'k_wo_prec', 'k_avg'],\n", " labels=[\n", - " \"Cosine similarity\",\n", - " \"Cosine similarity w/o precursor\",\n", - " \"Average cosine similarity\",\n", + " 'Cosine similarity',\n", + " 'Cosine similarity w/o precursor',\n", + " 'Average cosine similarity',\n", " ],\n", - " colors=[sns.color_palette(\"Paired\")[1], sns.color_palette(\"Paired\")[0], \"red\"],\n", + " colors=[sns.color_palette('Paired')[1], sns.color_palette('Paired')[0], 'red'],\n", " ratio=(7.2, 6),\n", " ylim=(0.35, 1),\n", ")\n", @@ -18594,105 +18594,105 @@ "if False:\n", " # Step 1: Drop duplicate group_id entries\n", " df_unique = (\n", - " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"]\n", - " .drop_duplicates(subset=[\"group_id\"])\n", + " df_test[df_test['Precursor_type'] == '[M+H]+']\n", + " .drop_duplicates(subset=['group_id'])\n", " .copy()\n", " )\n", "\n", " # Step 2: Explode the candidates list into individual rows\n", - " df_exploded = df_unique.explode(\"candidates\", ignore_index=True)\n", + " df_exploded = df_unique.explode('candidates', ignore_index=True)\n", "\n", " # Step 3: Ensure Metabolite creation\n", - " df_exploded[\"Metabolite\"] = df_exploded[\"candidates\"].apply(\n", + " df_exploded['Metabolite'] = df_exploded['candidates'].apply(\n", " safe_metabolite_creation\n", " )\n", - " assert df_exploded.groupby(\"group_id\").size().eq(49).all(), (\n", - " \"Some group_id is missing candidates\"\n", + " assert df_exploded.groupby('group_id').size().eq(49).all(), (\n", + " 'Some group_id is missing candidates'\n", " )\n", "\n", " # Step 4: Add idx with c_1, c_2 suffixes\n", - " df_exploded[\"idx\"] = (\n", - " 
\"spec\"\n", - " + df_exploded[\"group_id\"].astype(int).astype(str)\n", - " + \"_c_\"\n", - " + (df_exploded.groupby(\"group_id\").cumcount() + 1).astype(str)\n", + " df_exploded['idx'] = (\n", + " 'spec'\n", + " + df_exploded['group_id'].astype(int).astype(str)\n", + " + '_c_'\n", + " + (df_exploded.groupby('group_id').cumcount() + 1).astype(str)\n", " )\n", "\n", " # Step 5: Calculate formula and InChIKey, add empty strings for None values\n", - " df_exploded[\"formula\"] = df_exploded[\"Metabolite\"].apply(\n", - " lambda x: x.Formula if x else \"\"\n", + " df_exploded['formula'] = df_exploded['Metabolite'].apply(\n", + " lambda x: x.Formula if x else ''\n", " )\n", - " df_exploded[\"InChIKey\"] = df_exploded[\"Metabolite\"].apply(\n", - " lambda x: x.InChIKey if x else \"\"\n", + " df_exploded['InChIKey'] = df_exploded['Metabolite'].apply(\n", + " lambda x: x.InChIKey if x else ''\n", " )\n", "\n", " # Step 6: Map column names to ICEBERG format\n", " label_map = {\n", - " \"idx\": \"spec\",\n", - " \"Name\": \"name\",\n", - " \"Precursor_type\": \"ionization\",\n", - " \"candidates\": \"smiles\",\n", - " \"formula\": \"formula\",\n", - " \"InChIKey\": \"inchikey\",\n", + " 'idx': 'spec',\n", + " 'Name': 'name',\n", + " 'Precursor_type': 'ionization',\n", + " 'candidates': 'smiles',\n", + " 'formula': 'formula',\n", + " 'InChIKey': 'inchikey',\n", " }\n", "\n", " # Step 7: Save in ICEBERG-compatible format\n", " df_exploded.rename(columns=label_map)[\n", - " [\"dataset\", \"spec\", \"name\", \"ionization\", \"formula\", \"smiles\", \"inchikey\"]\n", + " ['dataset', 'spec', 'name', 'ionization', 'formula', 'smiles', 'inchikey']\n", " ].to_csv(\n", - " f\"{home}/data/metabolites/ms-pred/df_test_candidates.tsv\", index=False, sep=\"\\t\"\n", + " f'{home}/data/metabolites/ms-pred/df_test_candidates.tsv', index=False, sep='\\t'\n", " )\n", "\n", "if False:\n", " # Step 1: Drop duplicate ChallengeName entries\n", " df_unique = (\n", - " df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"]\n", - " .drop_duplicates(subset=[\"ChallengeName\"])\n", + " df_cas[df_cas['Precursor_type'] == '[M+H]+']\n", + " .drop_duplicates(subset=['ChallengeName'])\n", " .copy()\n", " )\n", "\n", " # Step 2: Explode the candidates list into individual rows\n", - " df_exploded = df_unique.explode(\"candidates\", ignore_index=True)\n", + " df_exploded = df_unique.explode('candidates', ignore_index=True)\n", "\n", " # Step 3: Ensure Metabolite creation\n", - " df_exploded[\"Metabolite\"] = df_exploded[\"candidates\"].apply(\n", + " df_exploded['Metabolite'] = df_exploded['candidates'].apply(\n", " safe_metabolite_creation\n", " )\n", - " assert df_exploded.groupby(\"ChallengeName\").size().eq(49).all(), (\n", - " \"Some ChallengeName is missing candidates\"\n", + " assert df_exploded.groupby('ChallengeName').size().eq(49).all(), (\n", + " 'Some ChallengeName is missing candidates'\n", " )\n", "\n", " # Step 4: Add idx with c_1, c_2 suffixes\n", - " df_exploded[\"idx\"] = (\n", - " \"spec\"\n", - " + df_exploded[\"ChallengeName\"].astype(str)\n", - " + \"_c_\"\n", - " + (df_exploded.groupby(\"ChallengeName\").cumcount() + 1).astype(str)\n", + " df_exploded['idx'] = (\n", + " 'spec'\n", + " + df_exploded['ChallengeName'].astype(str)\n", + " + '_c_'\n", + " + (df_exploded.groupby('ChallengeName').cumcount() + 1).astype(str)\n", " )\n", "\n", " # Step 5: Calculate formula and InChIKey, add empty strings for None values\n", - " df_exploded[\"formula\"] = df_exploded[\"Metabolite\"].apply(\n", - " lambda x: x.Formula if x else 
\"\"\n", + " df_exploded['formula'] = df_exploded['Metabolite'].apply(\n", + " lambda x: x.Formula if x else ''\n", " )\n", - " df_exploded[\"InChIKey\"] = df_exploded[\"Metabolite\"].apply(\n", - " lambda x: x.InChIKey if x else \"\"\n", + " df_exploded['InChIKey'] = df_exploded['Metabolite'].apply(\n", + " lambda x: x.InChIKey if x else ''\n", " )\n", "\n", " # Step 6: Map column names to ICEBERG format\n", " label_map = {\n", - " \"idx\": \"spec\",\n", - " \"ChallengeName\": \"name\",\n", - " \"Precursor_type\": \"ionization\",\n", - " \"candidates\": \"smiles\",\n", - " \"formula\": \"formula\",\n", - " \"InChIKey\": \"inchikey\",\n", + " 'idx': 'spec',\n", + " 'ChallengeName': 'name',\n", + " 'Precursor_type': 'ionization',\n", + " 'candidates': 'smiles',\n", + " 'formula': 'formula',\n", + " 'InChIKey': 'inchikey',\n", " }\n", "\n", " # Step 7: Save in ICEBERG-compatible format\n", " df_exploded.rename(columns=label_map)[\n", - " [\"dataset\", \"spec\", \"name\", \"ionization\", \"formula\", \"smiles\", \"inchikey\"]\n", + " ['dataset', 'spec', 'name', 'ionization', 'formula', 'smiles', 'inchikey']\n", " ].to_csv(\n", - " f\"{home}/data/metabolites/ms-pred/df_cas_candidates.tsv\", index=False, sep=\"\\t\"\n", + " f'{home}/data/metabolites/ms-pred/df_cas_candidates.tsv', index=False, sep='\\t'\n", " )" ] }, @@ -18702,45 +18702,45 @@ "metadata": {}, "outputs": [], "source": [ - "iceberg_dir = f\"{home}/repos/ms-pred/results/test_out_recovery/df_test_candidates/tree_preds_inten\"\n", + "iceberg_dir = f'{home}/repos/ms-pred/results/test_out_recovery/df_test_candidates/tree_preds_inten'\n", "df_ice = mspredReader.read(iceberg_dir)\n", - "df_ice[\"group_id\"] = df_ice[\"name\"].str.extract(r\"spec(\\d+)_\").astype(int)\n", - "df_ice[\"c\"] = df_ice[\"name\"].str.extract(r\"_c_(\\d+)\").astype(int)\n", + "df_ice['group_id'] = df_ice['name'].str.extract(r'spec(\\d+)_').astype(int)\n", + "df_ice['c'] = df_ice['name'].str.extract(r'_c_(\\d+)').astype(int)\n", "\n", - "for i, row in df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"].iterrows():\n", - " spec = row[\"peaks\"]\n", - " group_id = int(row[\"group_id\"])\n", - " eq_c = row[\"candidate_scores\"][\"eq_c\"]\n", + "for i, row in df_test[df_test['Precursor_type'] == '[M+H]+'].iterrows():\n", + " spec = row['peaks']\n", + " group_id = int(row['group_id'])\n", + " eq_c = row['candidate_scores']['eq_c']\n", "\n", - " df_candidate_matches = df_ice[df_ice[\"group_id\"] == group_id]\n", + " df_candidate_matches = df_ice[df_ice['group_id'] == group_id]\n", " df_candidate_matches = df_candidate_matches[\n", - " ~df_candidate_matches[\"c\"].isin([c + 1 for c in eq_c])\n", + " ~df_candidate_matches['c'].isin([c + 1 for c in eq_c])\n", " ]\n", "\n", " ssc = list(\n", - " df_candidate_matches[\"peaks\"].apply(\n", + " df_candidate_matches['peaks'].apply(\n", " lambda c_spec: spectral_cosine(c_spec, spec, transform=np.sqrt)\n", " )\n", " )\n", " sscwop = list(\n", - " df_candidate_matches[\"peaks\"].apply(\n", + " df_candidate_matches['peaks'].apply(\n", " lambda c_spec: spectral_cosine(\n", " c_spec,\n", " spec,\n", " transform=np.sqrt,\n", - " remove_mz=row[\"Metabolite\"].get_theoretical_precursor_mz(\n", - " ion_type=row[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " remove_mz=row['Metabolite'].get_theoretical_precursor_mz(\n", + " ion_type=row['Metabolite'].metadata['precursor_mode']\n", " ),\n", " )\n", " )\n", " )\n", " sscavg = [(ssc[i] + sscwop[i]) / 2.0 for i in range(len(ssc))]\n", "\n", - " 
row[\"candidate_scores\"][\"ice_sqrt_cosine\"] = sorted(ssc, reverse=True)[:10]\n", - " row[\"candidate_scores\"][\"ice_sqrt_cosine_wo_prec\"] = sorted(sscwop, reverse=True)[\n", + " row['candidate_scores']['ice_sqrt_cosine'] = sorted(ssc, reverse=True)[:10]\n", + " row['candidate_scores']['ice_sqrt_cosine_wo_prec'] = sorted(sscwop, reverse=True)[\n", " :10\n", " ]\n", - " row[\"candidate_scores\"][\"ice_sqrt_cosine_avg\"] = sorted(sscavg, reverse=True)[:10]" + " row['candidate_scores']['ice_sqrt_cosine_avg'] = sorted(sscavg, reverse=True)[:10]" ] }, { @@ -18777,12 +18777,12 @@ "metadata": {}, "outputs": [], "source": [ - "pos_mask = df_test[\"Precursor_type\"] == \"[M+H]+\"\n", - "df_test.loc[pos_mask, \"k_ice\"] = df_test.loc[pos_mask].apply(\n", - " lambda x: get_k(x, scoring_func=\"ice_sqrt_cosine\"), axis=1\n", + "pos_mask = df_test['Precursor_type'] == '[M+H]+'\n", + "df_test.loc[pos_mask, 'k_ice'] = df_test.loc[pos_mask].apply(\n", + " lambda x: get_k(x, scoring_func='ice_sqrt_cosine'), axis=1\n", ")\n", - "df_test.loc[pos_mask, \"k_wo_prec_ice\"] = df_test.loc[pos_mask].apply(\n", - " lambda x: get_k(x, scoring_func=\"ice_sqrt_cosine_wo_prec\"), axis=1\n", + "df_test.loc[pos_mask, 'k_wo_prec_ice'] = df_test.loc[pos_mask].apply(\n", + " lambda x: get_k(x, scoring_func='ice_sqrt_cosine_wo_prec'), axis=1\n", ")\n", "# df_test.loc[pos_mask, \"k_avg_ice\"] = df_test.loc[pos_mask].apply(lambda x: get_k(x, scoring_func=\"ice_sqrt_cosine_avg\"), axis=1)" ] @@ -18860,9 +18860,9 @@ ], "source": [ "fig, fig_data = plot_top_k_performance(\n", - " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"],\n", - " [\"k\", \"k_ice\"],\n", - " labels=[\"Fiora\", \"ICEBERG\"],\n", + " df_test[df_test['Precursor_type'] == '[M+H]+'],\n", + " ['k', 'k_ice'],\n", + " labels=['Fiora', 'ICEBERG'],\n", " ylim=(0.2, 1),\n", " colors=[lightblue_hex, lightpink_hex],\n", " ratio=(7.2, 6),\n", @@ -18892,9 +18892,9 @@ ], "source": [ "fig, fig_data = plot_top_k_performance(\n", - " df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"],\n", - " [\"k_wo_prec\", \"k_wo_prec_ice\"],\n", - " labels=[\"Fiora\", \"ICEBERG\"],\n", + " df_test[df_test['Precursor_type'] == '[M+H]+'],\n", + " ['k_wo_prec', 'k_wo_prec_ice'],\n", + " labels=['Fiora', 'ICEBERG'],\n", " ylim=(0.2, 1),\n", " ratio=(7.2, 6),\n", " colors=[lightblue_hex, lightpink_hex],\n", @@ -18920,8 +18920,9 @@ "metadata": {}, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\n", "import seaborn as sns\n", + "from matplotlib import pyplot as plt\n", + "\n", "from fiora.visualization.define_colors import *" ] }, @@ -18938,11 +18939,11 @@ "metadata": {}, "outputs": [], "source": [ - "df_test[\"group_id\"] = df_test[\"group_id\"].astype(int)\n", - "df_test.drop_duplicates(\"group_id\", keep=\"first\")[[\"group_id\", \"SMILES\"]].to_csv(\n", - " f\"{home}/data/metabolites/benchmarking/classyfire_input.csv\",\n", + "df_test['group_id'] = df_test['group_id'].astype(int)\n", + "df_test.drop_duplicates('group_id', keep='first')[['group_id', 'SMILES']].to_csv(\n", + " f'{home}/data/metabolites/benchmarking/classyfire_input.csv',\n", " header=None,\n", - " sep=\" \",\n", + " sep=' ',\n", " index=False,\n", ")\n", "# Use classyfire via text interface to produce output csv: http://classyfire.wishartlab.com/#chemical-text-query" @@ -18955,14 +18956,14 @@ "outputs": [], "source": [ "compound_classes = pd.read_csv(\n", - " f\"{home}/data/metabolites/benchmarking/classyfire_output.csv\"\n", + " f'{home}/data/metabolites/benchmarking/classyfire_output.csv'\n", ")\n", - 
"compound_classes[\"CompoundID\"] = pd.to_numeric(\n", - " compound_classes[\"CompoundID\"], errors=\"coerce\", downcast=\"integer\"\n", + "compound_classes['CompoundID'] = pd.to_numeric(\n", + " compound_classes['CompoundID'], errors='coerce', downcast='integer'\n", ")\n", - "compound_classes[[\"Category\", \"Value\"]] = compound_classes[\n", - " \"ClassifiedResults\"\n", - "].str.split(\":\", n=1, expand=True)" + "compound_classes[['Category', 'Value']] = compound_classes[\n", + " 'ClassifiedResults'\n", + "].str.split(':', n=1, expand=True)" ] }, { @@ -18971,10 +18972,10 @@ "metadata": {}, "outputs": [], "source": [ - "compound_classes[\"Value\"] = compound_classes[\"Value\"].fillna(\"\")\n", + "compound_classes['Value'] = compound_classes['Value'].fillna('')\n", "compound_classes = (\n", - " compound_classes.groupby([\"CompoundID\", \"Category\"])[\"Value\"]\n", - " .agg(\",\".join)\n", + " compound_classes.groupby(['CompoundID', 'Category'])['Value']\n", + " .agg(','.join)\n", " .unstack()\n", ")\n", "compound_classes.reset_index(inplace=True)\n", @@ -19118,9 +19119,9 @@ "metadata": {}, "outputs": [], "source": [ - "num_classes = len(compound_classes[\"Superclass\"].unique())\n", + "num_classes = len(compound_classes['Superclass'].unique())\n", "superclass_map = dict(\n", - " zip(compound_classes[\"Superclass\"].unique(), range(0, num_classes))\n", + " zip(compound_classes['Superclass'].unique(), range(0, num_classes))\n", ")" ] }, @@ -19130,7 +19131,7 @@ "metadata": {}, "outputs": [], "source": [ - "df_em = {\"embedding\": [], \"spectrum\": []}" + "df_em = {'embedding': [], 'spectrum': []}" ] }, { @@ -19148,19 +19149,19 @@ "classmap = {}\n", "supermap = {}\n", "submap = {}\n", - "pal = sns.color_palette(\"viridis\", num_classes)\n", + "pal = sns.color_palette('viridis', num_classes)\n", "\n", - "for i, d in df_test.drop_duplicates(\"group_id\", keep=\"first\").iterrows():\n", - " metabolite = d[\"Metabolite\"]\n", - " group_id = d[\"group_id\"]\n", - " superclass = compound_classes[compound_classes[\"CompoundID\"] == group_id].iloc[0][\n", - " \"Superclass\"\n", + "for i, d in df_test.drop_duplicates('group_id', keep='first').iterrows():\n", + " metabolite = d['Metabolite']\n", + " group_id = d['group_id']\n", + " superclass = compound_classes[compound_classes['CompoundID'] == group_id].iloc[0][\n", + " 'Superclass'\n", " ]\n", - " subclass = compound_classes[compound_classes[\"CompoundID\"] == group_id].iloc[0][\n", - " \"Subclass\"\n", + " subclass = compound_classes[compound_classes['CompoundID'] == group_id].iloc[0][\n", + " 'Subclass'\n", " ]\n", - " cclass = compound_classes[compound_classes[\"CompoundID\"] == group_id].iloc[0][\n", - " \"Class\"\n", + " cclass = compound_classes[compound_classes['CompoundID'] == group_id].iloc[0][\n", + " 'Class'\n", " ]\n", "\n", " supermap[group_id] = superclass\n", @@ -19170,12 +19171,12 @@ " data = metabolite.as_geometric_data(with_labels=False).to(dev)\n", " batch = geom.data.Batch.from_data_list([data])\n", " embedding = model.get_graph_embedding(batch)\n", - " if d[\"Precursor_type\"] == \"[M+H]+\":\n", - " df_em[\"embedding\"] += [embedding.flatten().cpu().detach().numpy()]\n", - " df_em[\"spectrum\"] += [d[\"peaks\"]]\n", + " if d['Precursor_type'] == '[M+H]+':\n", + " df_em['embedding'] += [embedding.flatten().cpu().detach().numpy()]\n", + " df_em['spectrum'] += [d['peaks']]\n", "\n", " test_group_id += [group_id]\n", - " test_smiles += [d[\"SMILES\"]]\n", + " test_smiles += [d['SMILES']]\n", " test_embeddings += 
[embedding.flatten().cpu().detach().numpy()]\n", " test_classes += [superclass]\n", " test_cclasses += [cclass]\n", @@ -19183,9 +19184,9 @@ " colors += [pal[superclass_map[superclass]]]\n", "\n", "\n", - "df_test[\"Superclass\"] = df_test[\"group_id\"].map(supermap)\n", - "df_test[\"Class\"] = df_test[\"group_id\"].map(classmap)\n", - "df_test[\"Subclass\"] = df_test[\"group_id\"].map(submap)" + "df_test['Superclass'] = df_test['group_id'].map(supermap)\n", + "df_test['Class'] = df_test['group_id'].map(classmap)\n", + "df_test['Subclass'] = df_test['group_id'].map(submap)" ] }, { @@ -19196,11 +19197,11 @@ "source": [ "Embedding_DF = pd.DataFrame(\n", " {\n", - " \"group_id\": test_group_id,\n", - " \"SMILES\": test_smiles,\n", - " \"Superclass\": test_classes,\n", - " \"Compound Class\": test_cclasses,\n", - " \"Embedding\": test_embeddings,\n", + " 'group_id': test_group_id,\n", + " 'SMILES': test_smiles,\n", + " 'Superclass': test_classes,\n", + " 'Compound Class': test_cclasses,\n", + " 'Embedding': test_embeddings,\n", " }\n", ")" ] @@ -19229,19 +19230,19 @@ "metadata": {}, "outputs": [], "source": [ - "from fiora.MS.spectral_scores import spectral_cosine, cosine\n", + "from fiora.MS.spectral_scores import cosine, spectral_cosine\n", "\n", "em_sim = []\n", "spec_sim = []\n", "\n", - "for i in range(len(df_em[\"embedding\"])):\n", - " em = df_em[\"embedding\"][i]\n", - " spec = df_em[\"spectrum\"][i]\n", - " for j in range(len(df_em[\"embedding\"])):\n", + "for i in range(len(df_em['embedding'])):\n", + " em = df_em['embedding'][i]\n", + " spec = df_em['spectrum'][i]\n", + " for j in range(len(df_em['embedding'])):\n", " if i >= j:\n", " continue\n", - " em2 = df_em[\"embedding\"][j]\n", - " spec2 = df_em[\"spectrum\"][j]\n", + " em2 = df_em['embedding'][j]\n", + " spec2 = df_em['spectrum'][j]\n", " spec_sim += [spectral_cosine(spec, spec2, transform=np.sqrt)]\n", " em_sim += [cosine(em, em2)]" ] @@ -19284,7 +19285,7 @@ } ], "source": [ - "raise KeyboardInterrupt(\"Stop before UMAPs - Careful now\")\n", + "raise KeyboardInterrupt('Stop before UMAPs - Careful now')\n", "# top = df_test.drop_duplicates(\"group_id\", keep=\"first\")[\"Superclass\"].value_counts().index[:8]\n", "# test_classes = [t if t in top else \" Other\" for t in test_classes]" ] @@ -19295,7 +19296,7 @@ "metadata": {}, "outputs": [], "source": [ - "lipid_index = np.where(np.array(test_classes) == \" Lipids and lipid-like molecules\")" + "lipid_index = np.where(np.array(test_classes) == ' Lipids and lipid-like molecules')" ] }, { @@ -19338,7 +19339,7 @@ } ], "source": [ - "df_test.drop_duplicates(\"group_id\", keep=\"first\")[\"Superclass\"].value_counts()" + "df_test.drop_duplicates('group_id', keep='first')['Superclass'].value_counts()" ] }, { @@ -19463,18 +19464,18 @@ " x=e[:, 0],\n", " y=e[:, 1],\n", " hue=test_classes,\n", - " edgecolor=\"white\",\n", + " edgecolor='white',\n", " linewidth=0.48,\n", " s=40,\n", - " palette=sns.color_palette(\"husl\", 13),\n", + " palette=sns.color_palette('husl', 13),\n", " style=test_classes,\n", - " markers=[\"o\", (4, 0, 45), \"D\", \"v\", \"p\", (4, 1, 0), \"X\"],\n", + " markers=['o', (4, 0, 45), 'D', 'v', 'p', (4, 1, 0), 'X'],\n", ") # markers=[\"o\", \"*\", \"s\", \"^\", \"D\", \"h\", (4,1,0),\"v\", \"X\", \"P\", \"p\", \"<\", (4,0, 45)] s=30,, hue_order=np.unique(test_classes)[::-1])#, order=[' Organic acids and derivatives', ' Organoheterocyclic compounds', ' Benzenoids', ' Alkaloids and derivatives', ' Phenylpropanoids and polyketides', ])#, 
palette=sns.color_palette(\"colorblind\") + [\"black\", \"gray\", \"white\"])\n", - "legend = ax.legend(loc=\"lower left\", bbox_to_anchor=(1, 0.5))\n", + "legend = ax.legend(loc='lower left', bbox_to_anchor=(1, 0.5))\n", "# plt.gca().set_aspect('equal', 'datalim')\n", "# plt.ylim([-2.50,13])\n", "# plt.xlim([-2.50,13])\n", - "ax.set_aspect(\"equal\", \"datalim\")\n", + "ax.set_aspect('equal', 'datalim')\n", "print(ax.get_xlim())\n", "ax.set_xlim((4.18, 14.15))\n", "ax.set_ylim(ax.get_xlim())\n", @@ -19482,7 +19483,7 @@ "default_marker_size = ax.collections[0].get_sizes()[0]\n", "\n", "# Print the default marker size\n", - "print(\"Default marker size:\", default_marker_size)\n", + "print('Default marker size:', default_marker_size)\n", "# ax.set_ylim([4, 12])\n", "# fig.savefig(f\"{home}/images/paper/umap_alt2.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/umap_alt2.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -19491,7 +19492,7 @@ "default_line_width = ax.collections[0].get_linewidths()[0]\n", "\n", "# Print the default line width\n", - "print(\"Default line width (edges):\", default_line_width)\n", + "print('Default line width (edges):', default_line_width)\n", "\n", "# Get the legend handles and labels\n", "handles = legend.legend_handles\n", @@ -19548,7 +19549,6 @@ } ], "source": [ - "import matplotlib\n", "import umap\n", "\n", "reset_matplotlib()\n", @@ -19564,23 +19564,23 @@ " x=e[lipid_index[0], 0],\n", " y=e[lipid_index[0], 1],\n", " hue=np.array(test_cclasses)[lipid_index[0]],\n", - " edgecolor=\"white\",\n", + " edgecolor='white',\n", " linewidth=0.48,\n", " s=40,\n", - " palette=sns.color_palette(\"husl\", 6),\n", + " palette=sns.color_palette('husl', 6),\n", " style=np.array(test_cclasses)[lipid_index[0]],\n", - " markers=[\"o\", (4, 0, 45), \"D\", \"v\", \"p\", (4, 1, 0), \"X\"],\n", + " markers=['o', (4, 0, 45), 'D', 'v', 'p', (4, 1, 0), 'X'],\n", ") # markers=[\"o\", \"*\", \"s\", \"^\", \"D\", \"h\", (4,1,0),\"v\", \"X\", \"P\", \"p\", \"<\", (4,0, 45)] s=30,, hue_order=np.unique(test_classes)[::-1])#, order=[' Organic acids and derivatives', ' Organoheterocyclic compounds', ' Benzenoids', ' Alkaloids and derivatives', ' Phenylpropanoids and polyketides', ])#, palette=sns.color_palette(\"colorblind\") + [\"black\", \"gray\", \"white\"])\n", - "legend = ax.legend(loc=\"lower left\", bbox_to_anchor=(1, 0.5))\n", + "legend = ax.legend(loc='lower left', bbox_to_anchor=(1, 0.5))\n", "# plt.gca().set_aspect('equal', 'datalim')\n", "# plt.ylim([-2.50,13])\n", "# plt.xlim([-2.50,13])\n", - "ax.set_aspect(\"equal\", \"datalim\")\n", + "ax.set_aspect('equal', 'datalim')\n", "# Get the default marker size\n", "default_marker_size = ax.collections[0].get_sizes()[0]\n", "\n", "# Print the default marker size\n", - "print(\"Default marker size:\", default_marker_size)\n", + "print('Default marker size:', default_marker_size)\n", "# ax.set_ylim([4, 12])\n", "# fig.savefig(f\"{home}/images/paper/umap_lipids_global.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "\n", @@ -19588,7 +19588,7 @@ "default_line_width = ax.collections[0].get_linewidths()[0]\n", "\n", "# Print the default line width\n", - "print(\"Default line width (edges):\", default_line_width)\n", + "print('Default line width (edges):', default_line_width)\n", "\n", "# Get the legend handles and labels\n", "handles = legend.legend_handles\n", @@ -19633,9 +19633,9 @@ ], "source": [ "# TODO use as mask to plot 
lipids\n", - "df_test.drop_duplicates(\"group_id\", keep=\"first\")[\n", - " \"Superclass\"\n", - "] == \" Lipids and lipid-like molecules\"" + "df_test.drop_duplicates('group_id', keep='first')[\n", + " 'Superclass'\n", + "] == ' Lipids and lipid-like molecules'" ] }, { @@ -19675,7 +19675,7 @@ } ], "source": [ - "sns.color_palette(\"muted\")" + "sns.color_palette('muted')" ] }, { @@ -19701,22 +19701,22 @@ "\n", "ax = sns.boxplot(\n", " ax=ax,\n", - " data=df_test[df_test[\"Superclass\"] != \"nan\"],\n", - " y=\"spectral_sqrt_cosine\",\n", + " data=df_test[df_test['Superclass'] != 'nan'],\n", + " y='spectral_sqrt_cosine',\n", " dodge=True,\n", " width=0.9,\n", " linewidth=1.5,\n", - " hue=\"Superclass\",\n", + " hue='Superclass',\n", " legend=False,\n", " palette=legend_colors,\n", " showfliers=False,\n", ")\n", - "ax.set_ylabel(\"cosine similarity\", fontsize=14)\n", + "ax.set_ylabel('cosine similarity', fontsize=14)\n", "ax.set_ylim([0, 1])\n", - "plt.tick_params(axis=\"y\", labelsize=14)\n", + "plt.tick_params(axis='y', labelsize=14)\n", "ax.set_xticks([])\n", "sns.despine(offset=10, trim=True)\n", - "ax.spines[\"bottom\"].set_visible(False)\n", + "ax.spines['bottom'].set_visible(False)\n", "\n", "adjust_box_widths(ax, 0.85)\n", "\n", @@ -19742,19 +19742,19 @@ } ], "source": [ - "from fiora.visualization.define_colors import adjust_box_widths, adjust_bar_widths\n", + "from fiora.visualization.define_colors import adjust_bar_widths, adjust_box_widths\n", "\n", "fig, axs = plt.subplots(2, 1, figsize=(10, 9), sharex=True, height_ratios=[1, 4])\n", "plt.subplots_adjust(hspace=0.1) # top=0.94, bottom=0.12, right=0.97, left=0.08)\n", "\n", "superclass_data = {}\n", - "for i, superclass in enumerate(df_test[\"Superclass\"].unique()):\n", - " if superclass != \"nan\":\n", + "for i, superclass in enumerate(df_test['Superclass'].unique()):\n", + " if superclass != 'nan':\n", " superclass_data[i] = {\n", - " \"Superclass\": superclass,\n", - " \"num_spectra\": sum(df_test[\"Superclass\"] == superclass),\n", - " \"num_compounds\": df_test[df_test[\"Superclass\"] == superclass][\n", - " \"group_id\"\n", + " 'Superclass': superclass,\n", + " 'num_spectra': sum(df_test['Superclass'] == superclass),\n", + " 'num_compounds': df_test[df_test['Superclass'] == superclass][\n", + " 'group_id'\n", " ].nunique(),\n", " }\n", "\n", @@ -19763,59 +19763,59 @@ "axs[0] = sns.barplot(\n", " ax=axs[0],\n", " data=superclass_data,\n", - " y=\"num_compounds\",\n", + " y='num_compounds',\n", " dodge=True,\n", " width=0.9,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " linewidth=1.5,\n", - " hue=\"Superclass\",\n", + " hue='Superclass',\n", " legend=False,\n", " palette=legend_colors,\n", ")\n", "for i, container in enumerate(axs[0].containers):\n", " axs[0].bar_label(axs[0].containers[i], fontsize=18)\n", " for bar in container:\n", - " bar_label = f\"n={superclass_data.iloc[i]['num_spectra']}\"\n", + " bar_label = f'n={superclass_data.iloc[i][\"num_spectra\"]}'\n", " axs[0].text(\n", " bar.get_x() + bar.get_width() / 2,\n", " -0.07, # Position at the base (y=0)\n", " bar_label,\n", - " ha=\"center\",\n", - " va=\"top\",\n", + " ha='center',\n", + " va='top',\n", " fontsize=14,\n", " )\n", "\n", - "axs[0].set_ylabel(\"\", fontsize=14)\n", - "axs[0].spines[\"bottom\"].set_visible(False)\n", - "axs[0].tick_params(axis=\"y\", labelsize=14)\n", + "axs[0].set_ylabel('', fontsize=14)\n", + "axs[0].spines['bottom'].set_visible(False)\n", + "axs[0].tick_params(axis='y', labelsize=14)\n", 
"adjust_bar_widths(axs[0], 0.85)\n", "\n", "\n", "axs[1] = sns.boxplot(\n", " ax=axs[1],\n", - " data=df_test[df_test[\"Superclass\"] != \"nan\"],\n", - " y=\"spectral_sqrt_cosine\",\n", + " data=df_test[df_test['Superclass'] != 'nan'],\n", + " y='spectral_sqrt_cosine',\n", " dodge=True,\n", " width=0.9,\n", " linewidth=1.5,\n", - " hue=\"Superclass\",\n", + " hue='Superclass',\n", " legend=False,\n", " palette=legend_colors,\n", " showfliers=False,\n", ")\n", - "axs[1].set_ylabel(\"\") # \"cosine similarity\", fontsize=14)\n", + "axs[1].set_ylabel('') # \"cosine similarity\", fontsize=14)\n", "axs[1].set_ylim([0, 1])\n", - "plt.tick_params(axis=\"y\", labelsize=14)\n", + "plt.tick_params(axis='y', labelsize=14)\n", "axs[1].set_xticks([])\n", "sns.despine(offset=10, trim=True)\n", - "axs[1].spines[\"bottom\"].set_visible(False)\n", + "axs[1].spines['bottom'].set_visible(False)\n", "adjust_box_widths(axs[1], 0.85)\n", "\n", - "plt.rc(\"axes\", labelsize=18)\n", - "plt.rc(\"legend\", fontsize=18)\n", + "plt.rc('axes', labelsize=18)\n", + "plt.rc('legend', fontsize=18)\n", "\n", - "axs[0].tick_params(axis=\"both\", labelsize=18)\n", - "axs[1].tick_params(axis=\"both\", labelsize=18)\n", + "axs[0].tick_params(axis='both', labelsize=18)\n", + "axs[1].tick_params(axis='both', labelsize=18)\n", "\n", "\n", "# fig.savefig(f\"{home}/images/paper/cosine_by_class_+hist.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -19840,8 +19840,8 @@ "metadata": {}, "outputs": [], "source": [ - "df_test[[\"group_id\", \"Superclass\", \"spectral_sqrt_cosine\"]].to_excel(\n", - " f\"{home}/images/paper/F4c.xlsx\"\n", + "df_test[['group_id', 'Superclass', 'spectral_sqrt_cosine']].to_excel(\n", + " f'{home}/images/paper/F4c.xlsx'\n", ")" ] }, @@ -19865,10 +19865,10 @@ "# print(df_test.groupby(\"Superclass\").group_id.unique().apply(len))\n", "# print(df_test.groupby(\"Superclass\").spectral_sqrt_cosine.median())\n", "# #df_test.groupby(\"Superclass\").spectral_sqrt_cosine.median()\n", - "result = df_test.groupby(\"Superclass\").agg(\n", - " num=(\"group_id\", lambda x: len(x.unique())),\n", - " spec=(\"group_id\", lambda x: len(x)),\n", - " cos=(\"spectral_sqrt_cosine\", \"median\"),\n", + "result = df_test.groupby('Superclass').agg(\n", + " num=('group_id', lambda x: len(x.unique())),\n", + " spec=('group_id', lambda x: len(x)),\n", + " cos=('spectral_sqrt_cosine', 'median'),\n", ")\n", "print(result)" ] @@ -19896,22 +19896,22 @@ "\n", "ax = sns.boxplot(\n", " ax=ax,\n", - " data=df_test[df_test[\"Class\"] != \"nan\"],\n", - " y=\"spectral_sqrt_cosine\",\n", + " data=df_test[df_test['Class'] != 'nan'],\n", + " y='spectral_sqrt_cosine',\n", " dodge=True,\n", " width=0.9,\n", " linewidth=1.5,\n", - " hue=\"Class\",\n", + " hue='Class',\n", " legend=False,\n", " palette=legend_colors,\n", " showfliers=False,\n", ")\n", - "ax.set_ylabel(\"cosine similarity\", fontsize=14)\n", + "ax.set_ylabel('cosine similarity', fontsize=14)\n", "ax.set_ylim([0, 1])\n", - "plt.tick_params(axis=\"y\", labelsize=14)\n", + "plt.tick_params(axis='y', labelsize=14)\n", "ax.set_xticks([])\n", "sns.despine(offset=10, trim=True)\n", - "ax.spines[\"bottom\"].set_visible(False)\n", + "ax.spines['bottom'].set_visible(False)\n", "\n", "adjust_box_widths(fig, 0.85)\n", "# fig.savefig(f\"{home}/images/paper/cosine_by_class.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -19943,9 +19943,9 @@ } ], "source": [ - "df_test[df_test[\"Superclass\"] == \" Lipids and lipid-like 
molecules\"].drop_duplicates(\n", - " \"group_id\", keep=\"first\"\n", - ")[\"Class\"].value_counts()" + "df_test[df_test['Superclass'] == ' Lipids and lipid-like molecules'].drop_duplicates(\n", + " 'group_id', keep='first'\n", + ")['Class'].value_counts()" ] }, { @@ -19965,26 +19965,26 @@ } ], "source": [ - "num_classes = len(compound_classes[\"Class\"].unique())\n", - "class_map = dict(zip(compound_classes[\"Class\"].unique(), range(0, num_classes)))\n", + "num_classes = len(compound_classes['Class'].unique())\n", + "class_map = dict(zip(compound_classes['Class'].unique(), range(0, num_classes)))\n", "test_embeddings = []\n", "test_classes = []\n", "colors = []\n", "classmap = {}\n", "supermap = {}\n", "submap = {}\n", - "pal = sns.color_palette(\"viridis\", num_classes)\n", + "pal = sns.color_palette('viridis', num_classes)\n", "\n", "for i, d in (\n", - " df_test[df_test[\"Superclass\"] == \" Lipids and lipid-like molecules\"]\n", - " .drop_duplicates(\"group_id\", keep=\"first\")\n", + " df_test[df_test['Superclass'] == ' Lipids and lipid-like molecules']\n", + " .drop_duplicates('group_id', keep='first')\n", " .iterrows()\n", "):\n", - " metabolite = d[\"Metabolite\"]\n", - " group_id = d[\"group_id\"]\n", + " metabolite = d['Metabolite']\n", + " group_id = d['group_id']\n", "\n", - " cclass = compound_classes[compound_classes[\"CompoundID\"] == group_id].iloc[0][\n", - " \"Class\"\n", + " cclass = compound_classes[compound_classes['CompoundID'] == group_id].iloc[0][\n", + " 'Class'\n", " ]\n", "\n", " data = metabolite.as_geometric_data(with_labels=False).to(dev)\n", @@ -20013,7 +20013,6 @@ } ], "source": [ - "import matplotlib\n", "import umap\n", "\n", "reset_matplotlib()\n", @@ -20029,14 +20028,14 @@ " x=e[:, 0],\n", " y=e[:, 1],\n", " hue=test_classes,\n", - " edgecolor=\"white\",\n", + " edgecolor='white',\n", " linewidth=0.48,\n", " s=40,\n", - " palette=sns.color_palette(\"husl\", 6),\n", + " palette=sns.color_palette('husl', 6),\n", " style=test_classes,\n", - " markers=[\"o\", (4, 0, 45), \"D\", \"v\", \"p\", (4, 1, 0), \"X\"],\n", + " markers=['o', (4, 0, 45), 'D', 'v', 'p', (4, 1, 0), 'X'],\n", ") # markers=[\"o\", \"*\", \"s\", \"^\", \"D\", \"h\", (4,1,0),\"v\", \"X\", \"P\", \"p\", \"<\", (4,0, 45)] s=30,, hue_order=np.unique(test_classes)[::-1])#, order=[' Organic acids and derivatives', ' Organoheterocyclic compounds', ' Benzenoids', ' Alkaloids and derivatives', ' Phenylpropanoids and polyketides', ])#, palette=sns.color_palette(\"colorblind\") + [\"black\", \"gray\", \"white\"])\n", - "legend = ax.legend(loc=\"lower left\", bbox_to_anchor=(1, 0.5))\n", + "legend = ax.legend(loc='lower left', bbox_to_anchor=(1, 0.5))\n", "# plt.gca().set_aspect('equal', 'datalim')\n", "# plt.ylim([-2.50,13])\n", "# plt.xlim([-2.50,13])\n", @@ -20048,7 +20047,7 @@ "default_marker_size = ax.collections[0].get_sizes()[0]\n", "\n", "# Print the default marker size\n", - "print(\"Default marker size:\", default_marker_size)\n", + "print('Default marker size:', default_marker_size)\n", "# ax.set_ylim([4, 12])\n", "# fig.savefig(f\"{home}/images/paper/umap_lipids_local.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/umap_lipids_local.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -20058,7 +20057,7 @@ "default_line_width = ax.collections[0].get_linewidths()[0]\n", "\n", "# Print the default line width\n", - "print(\"Default line width (edges):\", default_line_width)\n", + "print('Default line width 
(edges):', default_line_width)\n", "\n", "# Get the legend handles and labels\n", "handles = legend.legend_handles\n", @@ -20085,10 +20084,10 @@ "metadata": {}, "outputs": [], "source": [ - "df_cas[\"Dataset\"] = \"CASMI 16\"\n", - "df_cas22[\"Dataset\"] = \"CASMI 22\"\n", - "df_test[\"Dataset\"] = \"Test split\"\n", - "df_msnlib_test[\"Dataset\"] = \"MSnLib\"\n", + "df_cas['Dataset'] = 'CASMI 16'\n", + "df_cas22['Dataset'] = 'CASMI 22'\n", + "df_test['Dataset'] = 'Test split'\n", + "df_msnlib_test['Dataset'] = 'MSnLib'\n", "C = pd.concat([df_test, df_msnlib_test, df_cas, df_cas22], ignore_index=True)" ] }, @@ -20108,57 +20107,57 @@ "source": [ "calc_tanimoto = True # This may tak a long time\n", "if calc_tanimoto:\n", - " print(\"Calculating Tanimoto scores. This may take a while\")\n", - " df_cas.loc[:, \"tanimoto\"] = np.nan\n", + " print('Calculating Tanimoto scores. This may take a while')\n", + " df_cas.loc[:, 'tanimoto'] = np.nan\n", " for i, d in df_cas.iterrows():\n", - " df_cas.at[i, \"tanimoto\"] = (\n", - " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", - " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"]))\n", + " df_cas.at[i, 'tanimoto'] = (\n", + " df_train[df_train['dataset'] == 'training']['Metabolite']\n", + " .apply(lambda x: x.tanimoto_similarity(d['Metabolite']))\n", " .max()\n", " )\n", - " df_cas.at[i, \"tanimoto3\"] = (\n", - " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", - " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"], finger=\"morgan3\"))\n", + " df_cas.at[i, 'tanimoto3'] = (\n", + " df_train[df_train['dataset'] == 'training']['Metabolite']\n", + " .apply(lambda x: x.tanimoto_similarity(d['Metabolite'], finger='morgan3'))\n", " .max()\n", " )\n", "\n", - " df_cas22.loc[:, \"tanimoto\"] = np.nan\n", + " df_cas22.loc[:, 'tanimoto'] = np.nan\n", " for i, d in df_cas22.iterrows():\n", - " df_cas22.at[i, \"tanimoto\"] = (\n", - " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", - " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"]))\n", + " df_cas22.at[i, 'tanimoto'] = (\n", + " df_train[df_train['dataset'] == 'training']['Metabolite']\n", + " .apply(lambda x: x.tanimoto_similarity(d['Metabolite']))\n", " .max()\n", " )\n", - " df_cas22.at[i, \"tanimoto3\"] = (\n", - " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", - " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"], finger=\"morgan3\"))\n", + " df_cas22.at[i, 'tanimoto3'] = (\n", + " df_train[df_train['dataset'] == 'training']['Metabolite']\n", + " .apply(lambda x: x.tanimoto_similarity(d['Metabolite'], finger='morgan3'))\n", " .max()\n", " )\n", "\n", - " df_test.loc[:, \"tanimoto\"] = np.nan\n", + " df_test.loc[:, 'tanimoto'] = np.nan\n", " for i, d in df_test.iterrows():\n", - " df_test.at[i, \"tanimoto\"] = (\n", - " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", - " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"]))\n", + " df_test.at[i, 'tanimoto'] = (\n", + " df_train[df_train['dataset'] == 'training']['Metabolite']\n", + " .apply(lambda x: x.tanimoto_similarity(d['Metabolite']))\n", " .max()\n", " )\n", - " df_test.at[i, \"tanimoto3\"] = (\n", - " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", - " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"], finger=\"morgan3\"))\n", + " df_test.at[i, 'tanimoto3'] = (\n", + " df_train[df_train['dataset'] == 'training']['Metabolite']\n", + " .apply(lambda x: x.tanimoto_similarity(d['Metabolite'], 
finger='morgan3'))\n", " .max()\n", " )\n", "\n", - " df_msnlib_test.loc[:, \"tanimoto\"] = np.nan\n", - " df_msnlib_test.loc[:, \"tanimoto3\"] = np.nan\n", + " df_msnlib_test.loc[:, 'tanimoto'] = np.nan\n", + " df_msnlib_test.loc[:, 'tanimoto3'] = np.nan\n", " for i, d in df_msnlib_test.iterrows():\n", - " df_msnlib_test.at[i, \"tanimoto\"] = (\n", - " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", - " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"]))\n", + " df_msnlib_test.at[i, 'tanimoto'] = (\n", + " df_train[df_train['dataset'] == 'training']['Metabolite']\n", + " .apply(lambda x: x.tanimoto_similarity(d['Metabolite']))\n", " .max()\n", " )\n", - " df_msnlib_test.at[i, \"tanimoto3\"] = (\n", - " df_train[df_train[\"dataset\"] == \"training\"][\"Metabolite\"]\n", - " .apply(lambda x: x.tanimoto_similarity(d[\"Metabolite\"], finger=\"morgan3\"))\n", + " df_msnlib_test.at[i, 'tanimoto3'] = (\n", + " df_train[df_train['dataset'] == 'training']['Metabolite']\n", + " .apply(lambda x: x.tanimoto_similarity(d['Metabolite'], finger='morgan3'))\n", " .max()\n", " )" ] @@ -20169,32 +20168,32 @@ "metadata": {}, "outputs": [], "source": [ - "df_test[\"group_id\"] = df_test[\"group_id\"].astype(int)\n", + "df_test['group_id'] = df_test['group_id'].astype(int)\n", "new_value_offset = 100000\n", "\n", "\n", "# Function to assign unique metabolite identifiers\n", "def assign_metabolite_ids(df_ref: pd.DataFrame, metabolite_id_map):\n", " for i, data in df_ref.iterrows():\n", - " metabolite = data[\"Metabolite\"]\n", + " metabolite = data['Metabolite']\n", " is_new = True\n", " for id, other in metabolite_id_map.items():\n", " if metabolite == other:\n", " # Set the metabolite ID in the dataframe\n", - " df_ref.loc[i, \"group_id\"] = int(id)\n", + " df_ref.loc[i, 'group_id'] = int(id)\n", " is_new = False\n", " break\n", " if is_new:\n", " new_id = new_value_offset + len(metabolite_id_map)\n", - " df_ref.loc[i, \"group_id\"] = int(new_id)\n", + " df_ref.loc[i, 'group_id'] = int(new_id)\n", " metabolite_id_map[int(new_id)] = metabolite\n", "\n", "\n", "if calc_tanimoto:\n", " # Initialize the metabolite_id_map with metabolites from df_test\n", " metabolite_id_map = {}\n", - " for group_id in df_test[\"group_id\"].unique():\n", - " metabolite = df_test.loc[df_test[\"group_id\"] == group_id, \"Metabolite\"].iloc[0]\n", + " for group_id in df_test['group_id'].unique():\n", + " metabolite = df_test.loc[df_test['group_id'] == group_id, 'Metabolite'].iloc[0]\n", " metabolite_id_map[int(group_id)] = metabolite\n", "\n", " # Apply the function to each dataframe\n", @@ -20202,9 +20201,9 @@ " assign_metabolite_ids(df_cas, metabolite_id_map)\n", " assign_metabolite_ids(df_cas22, metabolite_id_map)\n", "\n", - " df_msnlib_test[\"group_id\"] = df_msnlib_test[\"group_id\"].astype(int)\n", - " df_cas[\"group_id\"] = df_cas[\"group_id\"].astype(int)\n", - " df_cas22[\"group_id\"] = df_cas22[\"group_id\"].astype(int)" + " df_msnlib_test['group_id'] = df_msnlib_test['group_id'].astype(int)\n", + " df_cas['group_id'] = df_cas['group_id'].astype(int)\n", + " df_cas22['group_id'] = df_cas22['group_id'].astype(int)" ] }, { @@ -20225,20 +20224,20 @@ " ax=ax,\n", " data=C,\n", " x=pd.cut(\n", - " C[C[\"Precursor_type\"] == \"[M+H]+\"][\"tanimoto3\"],\n", + " C[C['Precursor_type'] == '[M+H]+']['tanimoto3'],\n", " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", " ),\n", - " y=\"spectral_sqrt_cosine\",\n", - " palette=sns.color_palette(\"bright\"),\n", + " y='spectral_sqrt_cosine',\n", + " 
palette=sns.color_palette('bright'),\n", " capsize=0.0,\n", - " hue=\"Dataset\",\n", + " hue='Dataset',\n", " dodge=0.25,\n", - " estimator=\"median\",\n", + " estimator='median',\n", " ) # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", " plt.ylim([0, 1])\n", - " plt.legend(title=\"Dataset\", loc=\"lower left\")\n", - " plt.ylabel(\"Cosine similarity\")\n", - " plt.xlabel(\"Tanimoto similarity\")\n", + " plt.legend(title='Dataset', loc='lower left')\n", + " plt.ylabel('Cosine similarity')\n", + " plt.xlabel('Tanimoto similarity')\n", "\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_withcasmi.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_withcasmi.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -20260,57 +20259,57 @@ " ax=ax,\n", " data=C,\n", " x=pd.cut(\n", - " C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")][\n", - " \"tanimoto3\"\n", + " C[(C['Precursor_type'] == '[M+H]+') & (C['Dataset'] == 'Test split')][\n", + " 'tanimoto3'\n", " ],\n", " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", " ),\n", - " y=\"spectral_sqrt_cosine\",\n", + " y='spectral_sqrt_cosine',\n", " capsize=0.0,\n", " color=lightblue_hex,\n", - " estimator=\"median\",\n", + " estimator='median',\n", " ) # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", " sns.pointplot(\n", " ax=ax,\n", " data=C,\n", " x=pd.cut(\n", - " C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")][\n", - " \"tanimoto3\"\n", + " C[(C['Precursor_type'] == '[M+H]+') & (C['Dataset'] == 'Test split')][\n", + " 'tanimoto3'\n", " ],\n", " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", " ),\n", - " y=\"cfm_sqrt_cosine\",\n", + " y='cfm_sqrt_cosine',\n", " capsize=0.0,\n", - " color=\"gray\",\n", - " estimator=\"median\",\n", - " linestyle=\"--\",\n", + " color='gray',\n", + " estimator='median',\n", + " linestyle='--',\n", " ) # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", " sns.pointplot(\n", " ax=ax,\n", " data=C,\n", " x=pd.cut(\n", - " C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")][\n", - " \"tanimoto3\"\n", + " C[(C['Precursor_type'] == '[M+H]+') & (C['Dataset'] == 'Test split')][\n", + " 'tanimoto3'\n", " ],\n", " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", " ),\n", - " y=\"ice_sqrt_cosine\",\n", + " y='ice_sqrt_cosine',\n", " capsize=0.0,\n", " color=lightpink_hex,\n", - " estimator=\"median\",\n", + " estimator='median',\n", " ) # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", " plt.ylim([0.4, 1])\n", " custom_lines = [\n", - " plt.Line2D([0], [0], color=lightblue_hex, linestyle=\"-\", marker=\"o\"),\n", - " plt.Line2D([0], [0], color=\"gray\", linestyle=\"--\", marker=\"o\"),\n", - " plt.Line2D([0], [0], color=lightpink_hex, linestyle=\"-\", marker=\"o\"),\n", + " plt.Line2D([0], [0], color=lightblue_hex, linestyle='-', marker='o'),\n", + " plt.Line2D([0], [0], color='gray', linestyle='--', marker='o'),\n", + " plt.Line2D([0], [0], color=lightpink_hex, linestyle='-', marker='o'),\n", " ]\n", " plt.legend(\n", - " custom_lines, [\"Fiora\", \"CFM-ID\", \"ICEBERG\"], title=\"Software\", loc=\"upper left\"\n", + " custom_lines, ['Fiora', 'CFM-ID', 
'ICEBERG'], title='Software', loc='upper left'\n", " )\n", "\n", - " plt.ylabel(\"Cosine similarity\")\n", - " plt.xlabel(\"Tanimoto similarity\")\n", + " plt.ylabel('Cosine similarity')\n", + " plt.xlabel('Tanimoto similarity')\n", " # fig.savefig(f\"{home}/images/paper/tanimoto.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/tanimoto.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/tanimoto.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -20324,41 +20323,41 @@ "metadata": {}, "outputs": [], "source": [ - "datasets = [\"Test split\", \"MSnLib\", \"CASMI 16\", \"CASMI 22\"]\n", - "score_func = \"spectral_sqrt_cosine\"\n", + "datasets = ['Test split', 'MSnLib', 'CASMI 16', 'CASMI 22']\n", + "score_func = 'spectral_sqrt_cosine'\n", "if calc_tanimoto:\n", " fig, axs = plt.subplots(2, 1, figsize=(12, 9), height_ratios=[1, 4], sharex=True)\n", " plt.subplots_adjust(top=0.94, bottom=0.12, right=0.97, left=0.08, hspace=0.08)\n", "\n", " binned_data = pd.cut(\n", - " C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)][\"tanimoto3\"],\n", + " C[(C['Precursor_type'] == '[M+H]+') & C['Dataset'].isin(datasets)]['tanimoto3'],\n", " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", " ).dropna()\n", " grouped_data = (\n", - " C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)]\n", - " .groupby(binned_data)[\"group_id\"]\n", + " C[(C['Precursor_type'] == '[M+H]+') & C['Dataset'].isin(datasets)]\n", + " .groupby(binned_data)['group_id']\n", " .nunique()\n", - " .reset_index(name=\"unique_group_ids\")\n", + " .reset_index(name='unique_group_ids')\n", " )\n", "\n", " sns.barplot(\n", " ax=axs[0],\n", " x=binned_data.cat.categories,\n", - " y=\"unique_group_ids\",\n", + " y='unique_group_ids',\n", " data=grouped_data,\n", " linewidth=1.5,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " )\n", "\n", " for i in range(len(axs[0].containers)):\n", " axs[0].bar_label(axs[0].containers[i], fontsize=12)\n", - " axs[0].set_ylabel(\"\", fontsize=14)\n", - " axs[0].set_xlabel(\"\", fontsize=14)\n", - " axs[0].spines[\"top\"].set_visible(False)\n", - " axs[0].spines[\"right\"].set_visible(True)\n", - " axs[0].tick_params(axis=\"y\", labelsize=12)\n", + " axs[0].set_ylabel('', fontsize=14)\n", + " axs[0].set_xlabel('', fontsize=14)\n", + " axs[0].spines['top'].set_visible(False)\n", + " axs[0].spines['right'].set_visible(True)\n", + " axs[0].tick_params(axis='y', labelsize=12)\n", " adjust_bar_widths(axs[0], 0.5)\n", - " err = (\"ci\", 95)\n", + " err = ('ci', 95)\n", " sns.pointplot(\n", " ax=axs[1],\n", " data=C,\n", @@ -20366,44 +20365,44 @@ " y=score_func,\n", " capsize=0.0,\n", " color=lightblue_hex,\n", - " estimator=\"median\",\n", + " estimator='median',\n", " errorbar=err,\n", " ) # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", " sns.pointplot(\n", " ax=axs[1],\n", " data=C,\n", " x=binned_data,\n", - " y=score_func.replace(\"spectral\", \"cfm\"),\n", + " y=score_func.replace('spectral', 'cfm'),\n", " capsize=0.0,\n", - " color=\"gray\",\n", - " estimator=\"median\",\n", - " linestyle=\"--\",\n", + " color='gray',\n", + " estimator='median',\n", + " linestyle='--',\n", " errorbar=err,\n", " ) # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", " sns.pointplot(\n", 
" ax=axs[1],\n", " data=C,\n", " x=binned_data,\n", - " y=score_func.replace(\"spectral\", \"ice\"),\n", + " y=score_func.replace('spectral', 'ice'),\n", " capsize=0.0,\n", " color=lightpink_hex,\n", - " estimator=\"median\",\n", + " estimator='median',\n", " errorbar=err,\n", " ) # multiple=\"dodge\", common_norm=False, stat=\"density\") #multiple=\"dodge\", common_norm=False, stat=\"density\")\n", " plt.ylim([0.375, 1])\n", " custom_lines = [\n", - " plt.Line2D([0], [0], color=lightblue_hex, linestyle=\"-\", marker=\"o\"),\n", - " plt.Line2D([0], [0], color=\"gray\", linestyle=\"--\", marker=\"o\"),\n", - " plt.Line2D([0], [0], color=lightpink_hex, linestyle=\"-\", marker=\"o\"),\n", + " plt.Line2D([0], [0], color=lightblue_hex, linestyle='-', marker='o'),\n", + " plt.Line2D([0], [0], color='gray', linestyle='--', marker='o'),\n", + " plt.Line2D([0], [0], color=lightpink_hex, linestyle='-', marker='o'),\n", " ]\n", " plt.legend(\n", - " custom_lines, [\"Fiora\", \"CFM-ID\", \"ICEBERG\"], title=\"Software\", loc=\"upper left\"\n", + " custom_lines, ['Fiora', 'CFM-ID', 'ICEBERG'], title='Software', loc='upper left'\n", " )\n", "\n", - " plt.ylabel(\"Cosine similarity\")\n", - " plt.xlabel(\"Tanimoto similarity\")\n", - " axs[1].spines[\"top\"].set_visible(True)\n", - " axs[1].tick_params(axis=\"y\", labelsize=12)\n", + " plt.ylabel('Cosine similarity')\n", + " plt.xlabel('Tanimoto similarity')\n", + " axs[1].spines['top'].set_visible(True)\n", + " axs[1].tick_params(axis='y', labelsize=12)\n", "\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_+hist.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_+hist.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -20431,28 +20430,28 @@ ], "source": [ "# Filter the DataFrame based on Precursor_type and Dataset\n", - "C = C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)]\n", + "C = C[(C['Precursor_type'] == '[M+H]+') & C['Dataset'].isin(datasets)]\n", "score_cols = [\n", - " \"spectral_sqrt_cosine_wo_prec\",\n", - " \"ice_sqrt_cosine_wo_prec\",\n", - " \"cfm_sqrt_cosine_wo_prec\",\n", + " 'spectral_sqrt_cosine_wo_prec',\n", + " 'ice_sqrt_cosine_wo_prec',\n", + " 'cfm_sqrt_cosine_wo_prec',\n", "] # [\"spectral_sqrt_cosine_wo_prec\", \"ice_sqrt_cosine_wo_prec\", \"cfm_sqrt_cosine_wo_prec\"]\n", "\n", "# Melt the DataFrame to explode the scores into a long format\n", "melted_df = pd.melt(\n", " C,\n", - " id_vars=[\"tanimoto3\", \"Precursor_type\", \"Dataset\"],\n", + " id_vars=['tanimoto3', 'Precursor_type', 'Dataset'],\n", " value_vars=score_cols,\n", - " var_name=\"Software\",\n", - " value_name=\"Score\",\n", + " var_name='Software',\n", + " value_name='Score',\n", ")\n", "\n", "# Map the software names to more descriptive names\n", - "melted_df[\"Software\"] = melted_df[\"Software\"].map(\n", + "melted_df['Software'] = melted_df['Software'].map(\n", " {\n", - " \"spectral_sqrt_cosine_wo_prec\": \"Fiora\", # _wo_prec\n", - " \"ice_sqrt_cosine_wo_prec\": \"ICEBERG\",\n", - " \"cfm_sqrt_cosine_wo_prec\": \"CFM-ID\",\n", + " 'spectral_sqrt_cosine_wo_prec': 'Fiora', # _wo_prec\n", + " 'ice_sqrt_cosine_wo_prec': 'ICEBERG',\n", + " 'cfm_sqrt_cosine_wo_prec': 'CFM-ID',\n", " }\n", ")\n", "\n", @@ -20482,47 +20481,47 @@ "metadata": {}, "outputs": [], "source": [ - "datasets = [\"Test split\", \"MSnLib\", \"CASMI 16\", \"CASMI 22\"]\n", + "datasets = ['Test split', 'MSnLib', 'CASMI 16', 'CASMI 22']\n", "if calc_tanimoto:\n", " fig, axs = 
plt.subplots(2, 1, figsize=(12, 9), height_ratios=[1, 4], sharex=True)\n", " plt.subplots_adjust(top=0.94, bottom=0.12, right=0.97, left=0.08, hspace=0.08)\n", "\n", " binned_data = pd.cut(\n", - " C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)][\"tanimoto3\"],\n", + " C[(C['Precursor_type'] == '[M+H]+') & C['Dataset'].isin(datasets)]['tanimoto3'],\n", " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", " ).dropna()\n", " grouped_data = (\n", - " C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)]\n", - " .groupby(binned_data)[\"group_id\"]\n", + " C[(C['Precursor_type'] == '[M+H]+') & C['Dataset'].isin(datasets)]\n", + " .groupby(binned_data)['group_id']\n", " .nunique()\n", - " .reset_index(name=\"unique_group_ids\")\n", + " .reset_index(name='unique_group_ids')\n", " )\n", "\n", " sns.barplot(\n", " ax=axs[0],\n", " x=binned_data.cat.categories,\n", - " y=\"unique_group_ids\",\n", + " y='unique_group_ids',\n", " data=grouped_data,\n", " linewidth=1.5,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " )\n", "\n", " # Prepare binned data from melted_df\n", " melted_df = melted_df[\n", - " melted_df[\"Dataset\"].isin(datasets)\n", + " melted_df['Dataset'].isin(datasets)\n", " ] # Use only relevant datasets\n", " binned_data = pd.cut(\n", - " melted_df[\"tanimoto3\"], bins=[x / 10.0 for x in range(2, 11)]\n", + " melted_df['tanimoto3'], bins=[x / 10.0 for x in range(2, 11)]\n", " ).dropna()\n", "\n", " # Add the binned data as a new column to melted_df\n", - " melted_df[\"Binned\"] = binned_data\n", + " melted_df['Binned'] = binned_data\n", "\n", " # Group for bar plot\n", " grouped_data = (\n", - " melted_df.groupby(\"Binned\")[\"Software\"]\n", + " melted_df.groupby('Binned')['Software']\n", " .nunique()\n", - " .reset_index(name=\"unique_group_ids\")\n", + " .reset_index(name='unique_group_ids')\n", " )\n", "\n", " # Bar plot\n", @@ -20530,46 +20529,46 @@ "\n", " for i in range(len(axs[0].containers)):\n", " axs[0].bar_label(axs[0].containers[i], fontsize=12)\n", - " axs[0].set_ylabel(\"\", fontsize=14)\n", - " axs[0].set_xlabel(\"\", fontsize=14)\n", - " axs[0].spines[\"top\"].set_visible(False)\n", - " axs[0].spines[\"right\"].set_visible(True)\n", - " axs[0].tick_params(axis=\"y\", labelsize=12)\n", + " axs[0].set_ylabel('', fontsize=14)\n", + " axs[0].set_xlabel('', fontsize=14)\n", + " axs[0].spines['top'].set_visible(False)\n", + " axs[0].spines['right'].set_visible(True)\n", + " axs[0].tick_params(axis='y', labelsize=12)\n", " adjust_bar_widths(axs[0], 0.5)\n", "\n", - " err = (\"pi\", 50) # Error bar type\n", + " err = ('pi', 50) # Error bar type\n", "\n", " # Point plots with hue for Software\n", " sns.pointplot(\n", " ax=axs[1],\n", " data=melted_df,\n", - " x=\"Binned\",\n", - " y=\"Score\",\n", - " hue=\"Software\",\n", + " x='Binned',\n", + " y='Score',\n", + " hue='Software',\n", " capsize=0.0,\n", - " linestyle=[\"-\", \"-\", \"--\"],\n", - " palette=[lightblue_hex, lightpink_hex, \"gray\"],\n", - " estimator=\"median\",\n", + " linestyle=['-', '-', '--'],\n", + " palette=[lightblue_hex, lightpink_hex, 'gray'],\n", + " estimator='median',\n", " errorbar=err,\n", " dodge=0.25,\n", " )\n", " plt.ylim([0.2, 1])\n", "\n", - " plt.legend(loc=\"upper left\")\n", + " plt.legend(loc='upper left')\n", "\n", - " plt.ylabel(\"Cosine similarity\")\n", - " plt.xlabel(\"Tanimoto similarity\")\n", - " axs[1].spines[\"top\"].set_visible(True)\n", - " axs[1].tick_params(axis=\"both\", labelsize=13)\n", - " plt.rc(\"axes\", 
labelsize=14)\n", - " plt.rc(\"legend\", fontsize=14)\n", + " plt.ylabel('Cosine similarity')\n", + " plt.xlabel('Tanimoto similarity')\n", + " axs[1].spines['top'].set_visible(True)\n", + " axs[1].tick_params(axis='both', labelsize=13)\n", + " plt.rc('axes', labelsize=14)\n", + " plt.rc('legend', fontsize=14)\n", "\n", " # Count the number of samples in each Binned group\n", - " sample_counts = melted_df.groupby(\"Binned\").size()\n", + " sample_counts = melted_df.groupby('Binned').size()\n", "\n", " # Add annotations for the sample counts\n", " for idx, count in enumerate(sample_counts):\n", - " axs[1].text(idx, 0.21, f\"n={count}\", ha=\"center\", va=\"bottom\", fontsize=13)\n", + " axs[1].text(idx, 0.21, f'n={count}', ha='center', va='bottom', fontsize=13)\n", "\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_+hist_wop_iqr.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_+hist_wop_iqr.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -20597,37 +20596,37 @@ "\n", " # Bin the data based on 'tanimoto3'\n", " binned_data = pd.cut(\n", - " C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)][\"tanimoto3\"],\n", + " C[(C['Precursor_type'] == '[M+H]+') & C['Dataset'].isin(datasets)]['tanimoto3'],\n", " bins=[x / 10.0 for x in range(2, 11, 1)],\n", " )\n", "\n", " # Group by both binned_data and 'Dataset', then count unique 'group_id'\n", " grouped_data = (\n", - " C[(C[\"Precursor_type\"] == \"[M+H]+\") & C[\"Dataset\"].isin(datasets)]\n", - " .groupby([binned_data, \"Dataset\"])[\"group_id\"]\n", + " C[(C['Precursor_type'] == '[M+H]+') & C['Dataset'].isin(datasets)]\n", + " .groupby([binned_data, 'Dataset'])['group_id']\n", " .nunique()\n", - " .reset_index(name=\"unique_group_ids\")\n", + " .reset_index(name='unique_group_ids')\n", " )\n", "\n", " sns.barplot(\n", - " x=\"tanimoto3\",\n", - " y=\"unique_group_ids\",\n", - " hue=\"Dataset\",\n", + " x='tanimoto3',\n", + " y='unique_group_ids',\n", + " hue='Dataset',\n", " data=grouped_data,\n", - " palette=\"tab10\",\n", + " palette='tab10',\n", " hue_order=datasets,\n", - " edgecolor=\"black\",\n", + " edgecolor='black',\n", " gap=0.15,\n", " )\n", "\n", - " plt.xlabel(\"Tanimoto similarity\")\n", - " plt.ylabel(\"Number of compounds\")\n", + " plt.xlabel('Tanimoto similarity')\n", + " plt.ylabel('Number of compounds')\n", " # plt.xticks(rotation=45)\n", - " plt.legend(title=\"Dataset\", loc=\"upper right\")\n", + " plt.legend(title='Dataset', loc='upper right')\n", " plt.tight_layout()\n", - " axs.tick_params(axis=\"both\", labelsize=13)\n", - " plt.rc(\"axes\", labelsize=14)\n", - " plt.rc(\"legend\", fontsize=14)\n", + " axs.tick_params(axis='both', labelsize=13)\n", + " plt.rc('axes', labelsize=14)\n", + " plt.rc('legend', fontsize=14)\n", "\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_distribution.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", " # fig.savefig(f\"{home}/images/paper/tanimoto_distribution.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -20652,31 +20651,31 @@ "source": [ "if calc_tanimoto:\n", " medians = (\n", - " C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")]\n", + " C[(C['Precursor_type'] == '[M+H]+') & (C['Dataset'] == 'Test split')]\n", " .groupby(\n", " pd.cut(\n", - " C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")][\n", - " \"tanimoto3\"\n", + " C[(C['Precursor_type'] == 
'[M+H]+') & (C['Dataset'] == 'Test split')][\n", + " 'tanimoto3'\n", " ],\n", " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", " )\n", - " )[\"spectral_sqrt_cosine\"]\n", + " )['spectral_sqrt_cosine']\n", " .median()\n", " )\n", - " print(f\"Fiora: {medians.min() / medians.max()}\")\n", + " print(f'Fiora: {medians.min() / medians.max()}')\n", " medians = (\n", - " C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")]\n", + " C[(C['Precursor_type'] == '[M+H]+') & (C['Dataset'] == 'Test split')]\n", " .groupby(\n", " pd.cut(\n", - " C[(C[\"Precursor_type\"] == \"[M+H]+\") & (C[\"Dataset\"] == \"Test split\")][\n", - " \"tanimoto3\"\n", + " C[(C['Precursor_type'] == '[M+H]+') & (C['Dataset'] == 'Test split')][\n", + " 'tanimoto3'\n", " ],\n", " bins=[x / 10.0 for x in list(range(2, 11, 1))],\n", " )\n", - " )[\"ice_sqrt_cosine\"]\n", + " )['ice_sqrt_cosine']\n", " .median()\n", " )\n", - " print(f\"ICEBERG: {medians.min() / medians.max()}\")" + " print(f'ICEBERG: {medians.min() / medians.max()}')" ] }, { @@ -20705,17 +20704,17 @@ "sns.histplot(\n", " C,\n", " ax=ax,\n", - " x=\"spectral_sqrt_cosine\",\n", - " hue=\"Dataset\",\n", + " x='spectral_sqrt_cosine',\n", + " hue='Dataset',\n", " linewidth=1,\n", - " multiple=\"stack\",\n", - " edgecolor=\"black\",\n", + " multiple='stack',\n", + " edgecolor='black',\n", ")\n", - "ax.set_xlabel(\"Spectral cosine similarity\")\n", + "ax.set_xlabel('Spectral cosine similarity')\n", "\n", - "plt.rc(\"axes\", labelsize=14)\n", - "plt.rc(\"legend\", fontsize=14)\n", - "ax.tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", + "plt.rc('axes', labelsize=14)\n", + "plt.rc('legend', fontsize=14)\n", + "ax.tick_params(axis='both', which='major', labelsize=13)\n", "# fig.savefig(f\"{home}/images/paper/histplot_cosine.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/histplot_cosine.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/histplot_cosine.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -20753,7 +20752,7 @@ " 2,\n", " 1,\n", " figsize=(12, 12),\n", - " gridspec_kw={\"height_ratios\": [1, 1]},\n", + " gridspec_kw={'height_ratios': [1, 1]},\n", " sharex=True,\n", " sharey=True,\n", ")\n", @@ -20761,24 +20760,24 @@ "\n", "C = pd.concat([df_test, df_msnlib_test, df_cas, df_cas22], ignore_index=True)\n", "\n", - "sns.kdeplot(C, ax=axs[0], x=\"spectral_sqrt_cosine\", color=bluepink[0], linewidth=2)\n", - "sns.kdeplot(C, ax=axs[0], x=\"ice_sqrt_cosine\", color=bluepink[1], linewidth=2)\n", - "sns.kdeplot(C, ax=axs[0], x=\"cfm_sqrt_cosine\", color=\"gray\", linewidth=2)\n", - "axs[0].legend(title=\"Default\", labels=[\"Fiora\", \"ICEBERG\", \"CFM-ID\"], loc=\"upper left\")\n", + "sns.kdeplot(C, ax=axs[0], x='spectral_sqrt_cosine', color=bluepink[0], linewidth=2)\n", + "sns.kdeplot(C, ax=axs[0], x='ice_sqrt_cosine', color=bluepink[1], linewidth=2)\n", + "sns.kdeplot(C, ax=axs[0], x='cfm_sqrt_cosine', color='gray', linewidth=2)\n", + "axs[0].legend(title='Default', labels=['Fiora', 'ICEBERG', 'CFM-ID'], loc='upper left')\n", "\n", "sns.kdeplot(\n", - " C, ax=axs[1], x=\"spectral_sqrt_cosine_wo_prec\", color=bluepink[0], linewidth=2\n", + " C, ax=axs[1], x='spectral_sqrt_cosine_wo_prec', color=bluepink[0], linewidth=2\n", ")\n", - "sns.kdeplot(C, ax=axs[1], x=\"ice_sqrt_cosine_wo_prec\", color=bluepink[1], linewidth=2)\n", - "sns.kdeplot(C, ax=axs[1], x=\"cfm_sqrt_cosine_wo_prec\", 
color=\"gray\", linewidth=2)\n", + "sns.kdeplot(C, ax=axs[1], x='ice_sqrt_cosine_wo_prec', color=bluepink[1], linewidth=2)\n", + "sns.kdeplot(C, ax=axs[1], x='cfm_sqrt_cosine_wo_prec', color='gray', linewidth=2)\n", "axs[1].legend(\n", - " title=\"Without precursor\", labels=[\"Fiora\", \"ICEBERG\", \"CFM-ID\"], loc=\"upper left\"\n", + " title='Without precursor', labels=['Fiora', 'ICEBERG', 'CFM-ID'], loc='upper left'\n", ")\n", - "axs[1].set_xlabel(\"Spectral cosine similarity\")\n", + "axs[1].set_xlabel('Spectral cosine similarity')\n", "\n", - "plt.rc(\"axes\", labelsize=14)\n", - "plt.rc(\"legend\", fontsize=14)\n", - "ax.tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", + "plt.rc('axes', labelsize=14)\n", + "plt.rc('legend', fontsize=14)\n", + "ax.tick_params(axis='both', which='major', labelsize=13)\n", "# fig.savefig(f\"{home}/images/paper/kdeplots_cosine.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/kdeplots_cosine.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/kdeplots_cosine.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -20825,18 +20824,18 @@ } ], "source": [ - "score = \"spectral_sqrt_cosine\"\n", - "score_w = score + \"_wo_prec\"\n", + "score = 'spectral_sqrt_cosine'\n", + "score_w = score + '_wo_prec'\n", "\n", - "for dataset in [\"Test split\", \"MSnLib\", \"CASMI 16\", \"CASMI 22\"]:\n", - " for p in C[\"Precursor_type\"].unique()[:2]:\n", - " Cx = C[(C[\"Dataset\"] == dataset) & (C[\"Precursor_type\"] == p)]\n", + "for dataset in ['Test split', 'MSnLib', 'CASMI 16', 'CASMI 22']:\n", + " for p in C['Precursor_type'].unique()[:2]:\n", + " Cx = C[(C['Dataset'] == dataset) & (C['Precursor_type'] == p)]\n", " dif = Cx[score].fillna(0.0) - Cx[score_w].fillna(0.0)\n", " abs_avg_dif = np.mean(dif)\n", " rel_avg_dif = np.mean(dif / Cx[score].fillna(0.0))\n", "\n", " print(\n", - " f\"Cosine loss from precursor removal (for {dataset} {p}): \\t{abs_avg_dif:.2f} ({rel_avg_dif * 100:2.1f}%)\"\n", + " f'Cosine loss from precursor removal (for {dataset} {p}): \\t{abs_avg_dif:.2f} ({rel_avg_dif * 100:2.1f}%)'\n", " )" ] }, @@ -20868,19 +20867,19 @@ } ], "source": [ - "print(\"Avg loss:\", np.mean(C[score].fillna(0.0) - C[score_w].fillna(0.0)))\n", + "print('Avg loss:', np.mean(C[score].fillna(0.0) - C[score_w].fillna(0.0)))\n", "print(\n", - " \"Avg pos loss:\",\n", + " 'Avg pos loss:',\n", " np.mean(\n", - " C[C[\"Precursor_type\"] == \"[M+H]+\"][score].fillna(0.0)\n", - " - C[C[\"Precursor_type\"] == \"[M+H]+\"][score_w].fillna(0.0)\n", + " C[C['Precursor_type'] == '[M+H]+'][score].fillna(0.0)\n", + " - C[C['Precursor_type'] == '[M+H]+'][score_w].fillna(0.0)\n", " ),\n", ")\n", "print(\n", - " \"Avg neg loss:\",\n", + " 'Avg neg loss:',\n", " np.mean(\n", - " C[C[\"Precursor_type\"] == \"[M-H]-\"][score].fillna(0.0)\n", - " - C[C[\"Precursor_type\"] == \"[M-H]-\"][score_w].fillna(0.0)\n", + " C[C['Precursor_type'] == '[M-H]-'][score].fillna(0.0)\n", + " - C[C['Precursor_type'] == '[M-H]-'][score_w].fillna(0.0)\n", " ),\n", ")" ] @@ -20912,65 +20911,65 @@ "from fiora.visualization.define_colors import set_light_theme, tri_palette\n", "\n", "fig, axs = plt.subplots(\n", - " 2, 1, figsize=(12, 14), gridspec_kw={\"height_ratios\": [1, 5]}, sharex=True\n", + " 2, 1, figsize=(12, 14), gridspec_kw={'height_ratios': [1, 5]}, sharex=True\n", ")\n", "plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", "set_light_theme()\n", 
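The dashed guides in the retention-time panel below mark a corridor of plus/minus 30 seconds around the identity line. A hypothetical helper (not part of the notebook) to quantify how many unique compounds fall inside that corridor, assuming the RETENTIONTIME and RT_pred columns are in minutes as used here:

    import pandas as pd

    def frac_within(df: pd.DataFrame, seconds: float = 30.0) -> float:
        # Absolute prediction error in minutes, compared against the
        # corridor half-width; rows without predictions are dropped.
        err = (df['RT_pred'] - df['RETENTIONTIME']).abs().dropna()
        return float((err <= seconds / 60.0).mean())

    # e.g. print(f'{frac_within(df_test_unique):.1%} within 30 s')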
"set_light_theme()\n", "\n", - "df_test_unique = df_test.dropna(subset=[\"RETENTIONTIME\"]).drop_duplicates(\n", - " subset=\"SMILES\", keep=\"first\"\n", + "df_test_unique = df_test.dropna(subset=['RETENTIONTIME']).drop_duplicates(\n", + " subset='SMILES', keep='first'\n", ")\n", "\n", - "plt.rc(\"legend\", fontsize=20)\n", + "plt.rc('legend', fontsize=20)\n", "\n", "# sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", "sns.kdeplot(\n", " ax=axs[0],\n", " data=df_test_unique,\n", - " x=\"RETENTIONTIME\",\n", + " x='RETENTIONTIME',\n", " bw_adjust=0.25,\n", - " palette=[\"gray\"],\n", + " palette=['gray'],\n", " fill=True,\n", - " hue=\"Dataset\",\n", + " hue='Dataset',\n", " alpha=0.7,\n", ") # , multiple=\"stack\") #hue=\"Precursor_type\",\n", "sns.kdeplot(\n", " ax=axs[0],\n", " data=df_test_unique,\n", - " x=\"RETENTIONTIME\",\n", + " x='RETENTIONTIME',\n", " bw_adjust=0.25,\n", - " color=\"#696969\",\n", + " color='#696969',\n", " linewidth=1.7,\n", ") # , multiple=\"stack\") #hue=\"Precursor_type\",\n", "# axs[0].legend(title=\"Dataset\", loc=\"upper right\")\n", "\n", "\n", - "axs[0].spines[\"top\"].set_visible(False)\n", - "axs[0].spines[\"right\"].set_visible(False)\n", + "axs[0].spines['top'].set_visible(False)\n", + "axs[0].spines['right'].set_visible(False)\n", "\n", "\n", "sns.scatterplot(\n", " ax=axs[1],\n", " data=df_test_unique,\n", - " x=\"RETENTIONTIME\",\n", - " y=\"RT_pred\",\n", - " color=\"gray\",\n", + " x='RETENTIONTIME',\n", + " y='RT_pred',\n", + " color='gray',\n", " marker=(4, 1, 0),\n", " s=60,\n", - " edgecolor=\"None\",\n", + " edgecolor='None',\n", ") # , hue=\"library\", palette=tri_palette, style=\"library\", color=\"gray\")\n", - "axs[1].set_ylim([2, df_test_unique[\"RETENTIONTIME\"].max() + 0.5])\n", - "axs[1].set_xlim([2, df_test_unique[\"RETENTIONTIME\"].max() + 0.5])\n", - "axs[1].set_ylabel(\"Predicted retention time (in minutes)\")\n", - "axs[1].set_xlabel(\"Observed retention time (in minutes)\")\n", + "axs[1].set_ylim([2, df_test_unique['RETENTIONTIME'].max() + 0.5])\n", + "axs[1].set_xlim([2, df_test_unique['RETENTIONTIME'].max() + 0.5])\n", + "axs[1].set_ylabel('Predicted retention time (in minutes)')\n", + "axs[1].set_xlabel('Observed retention time (in minutes)')\n", "line = [0, 100]\n", - "sns.lineplot(ax=axs[1], x=line, y=line, color=\"black\")\n", + "sns.lineplot(ax=axs[1], x=line, y=line, color='black')\n", "sns.lineplot(\n", - " ax=axs[1], x=line, y=[x + 30 / 60.0 for x in line], color=\"black\", linestyle=\"--\"\n", + " ax=axs[1], x=line, y=[x + 30 / 60.0 for x in line], color='black', linestyle='--'\n", ")\n", "sns.lineplot(\n", - " ax=axs[1], x=line, y=[x - 30 / 60.0 for x in line], color=\"black\", linestyle=\"--\"\n", + " ax=axs[1], x=line, y=[x - 30 / 60.0 for x in line], color='black', linestyle='--'\n", ")\n", "\n", "# sns.lineplot(ax=axs[1], x=line, y=[1.1*x for x in line], color=\"black\", linestyle='--')\n", @@ -20978,24 +20977,24 @@ "\n", "# Text sizes\n", "\n", - "axs[0].tick_params(axis=\"both\", labelsize=16)\n", - "axs[1].tick_params(axis=\"both\", labelsize=16)\n", - "plt.rc(\"axes\", labelsize=20)\n", - "plt.rc(\"legend\", fontsize=20)\n", + "axs[0].tick_params(axis='both', labelsize=16)\n", + "axs[1].tick_params(axis='both', labelsize=16)\n", + "plt.rc('axes', labelsize=20)\n", + "plt.rc('legend', fontsize=20)\n", "\n", "\n", "fig.savefig(\n", - " 
f\"{home}/images/paper/rt.svg\",\n", - " format=\"svg\",\n", + " f'{home}/images/paper/rt.svg',\n", + " format='svg',\n", " dpi=600,\n", - " bbox_inches=\"tight\",\n", + " bbox_inches='tight',\n", " pad_inches=0.1,\n", ")\n", "fig.savefig(\n", - " f\"{home}/images/paper/rt.pdf\",\n", - " format=\"pdf\",\n", + " f'{home}/images/paper/rt.pdf',\n", + " format='pdf',\n", " dpi=600,\n", - " bbox_inches=\"tight\",\n", + " bbox_inches='tight',\n", " pad_inches=0.1,\n", ")\n", "# fig.savefig(f\"{home}/images/paper/rt.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -21029,16 +21028,16 @@ } ], "source": [ - "print(\"Pearson Corr Coef:\")\n", + "print('Pearson Corr Coef:')\n", "print(\n", - " \"GNN PC\",\n", + " 'GNN PC',\n", " np.corrcoef(\n", - " df_test_unique[\"RETENTIONTIME\"], df_test_unique[\"RT_pred\"].dropna(), dtype=float\n", + " df_test_unique['RETENTIONTIME'], df_test_unique['RT_pred'].dropna(), dtype=float\n", " )[0, 1],\n", ")\n", "print(\n", - " \"GNN R2\",\n", - " r2_score(df_test_unique[\"RETENTIONTIME\"], df_test_unique[\"RT_pred\"].dropna()),\n", + " 'GNN R2',\n", + " r2_score(df_test_unique['RETENTIONTIME'], df_test_unique['RT_pred'].dropna()),\n", ")" ] }, @@ -21048,11 +21047,11 @@ "metadata": {}, "outputs": [], "source": [ - "df_train_rt = df_train[~df_train[\"RETENTIONTIME\"].isna()].drop_duplicates(\n", - " subset=\"group_id\", keep=\"first\"\n", + "df_train_rt = df_train[~df_train['RETENTIONTIME'].isna()].drop_duplicates(\n", + " subset='group_id', keep='first'\n", ")\n", - "df_test_rt = df_test[~df_test[\"RETENTIONTIME\"].isna()].drop_duplicates(\n", - " subset=\"group_id\", keep=\"first\"\n", + "df_test_rt = df_test[~df_test['RETENTIONTIME'].isna()].drop_duplicates(\n", + " subset='group_id', keep='first'\n", ")" ] }, @@ -21082,13 +21081,13 @@ "metadata": {}, "outputs": [], "source": [ - "with open(f\"{home}/data/metabolites/rt/train.logp\") as infile:\n", + "with open(f'{home}/data/metabolites/rt/train.logp') as infile:\n", " lines = infile.readlines()\n", "logps = []\n", "for line in lines[11:]:\n", - " logp = float(line.strip().split(\": \")[1])\n", + " logp = float(line.strip().split(': ')[1])\n", " logps.append(logp)\n", - "df_train_rt[\"logp\"] = logps" + "df_train_rt['logp'] = logps" ] }, { @@ -21097,13 +21096,13 @@ "metadata": {}, "outputs": [], "source": [ - "with open(f\"{home}/data/metabolites/rt/test.logp\") as infile:\n", + "with open(f'{home}/data/metabolites/rt/test.logp') as infile:\n", " lines = infile.readlines()\n", "logps = []\n", "for line in lines[11:]:\n", - " logp = float(line.strip().split(\": \")[1])\n", + " logp = float(line.strip().split(': ')[1])\n", " logps.append(logp)\n", - "df_test_rt[\"logp\"] = logps" + "df_test_rt['logp'] = logps" ] }, { @@ -21112,9 +21111,9 @@ "metadata": {}, "outputs": [], "source": [ - "df_val = df_train[df_train[\"dataset\"] == \"validation\"]\n", - "df_train = df_train[df_train[\"dataset\"] == \"training\"]\n", - "df_train_rt = df_train_rt[df_train_rt[\"dataset\"] == \"training\"]" + "df_val = df_train[df_train['dataset'] == 'validation']\n", + "df_train = df_train[df_train['dataset'] == 'training']\n", + "df_train_rt = df_train_rt[df_train_rt['dataset'] == 'training']" ] }, { @@ -21159,31 +21158,31 @@ ], "source": [ "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(\n", - " df_train_rt[\"logp\"], df_train_rt[\"RETENTIONTIME\"]\n", + " df_train_rt['logp'], df_train_rt['RETENTIONTIME']\n", ")\n", - "print(\"TEST SPLIT:\\n\")\n", - "print(\"Pearson Corr Coef:\")\n", + 
"print('TEST SPLIT:\\n')\n", + "print('Pearson Corr Coef:')\n", "print(\n", - " \"GNN\",\n", + " 'GNN',\n", " np.corrcoef(\n", - " df_test_rt[\"RETENTIONTIME\"], df_test_rt[\"RT_pred\"].dropna(), dtype=float\n", + " df_test_rt['RETENTIONTIME'], df_test_rt['RT_pred'].dropna(), dtype=float\n", " )[0, 1],\n", ")\n", "print(\n", - " \"LR \",\n", + " 'LR ',\n", " np.corrcoef(\n", - " df_test_rt[\"RETENTIONTIME\"],\n", - " intercept + slope * df_test_rt[\"logp\"].dropna(),\n", + " df_test_rt['RETENTIONTIME'],\n", + " intercept + slope * df_test_rt['logp'].dropna(),\n", " dtype=float,\n", " )[0, 1],\n", ")\n", "\n", - "print(\"R2\")\n", - "print(\"GNN\", r2_score(df_test_rt[\"RETENTIONTIME\"], df_test_rt[\"RT_pred\"].dropna()))\n", + "print('R2')\n", + "print('GNN', r2_score(df_test_rt['RETENTIONTIME'], df_test_rt['RT_pred'].dropna()))\n", "print(\n", - " \"LR \",\n", + " 'LR ',\n", " r2_score(\n", - " df_test_rt[\"RETENTIONTIME\"], intercept + slope * df_test_rt[\"logp\"].dropna()\n", + " df_test_rt['RETENTIONTIME'], intercept + slope * df_test_rt['logp'].dropna()\n", " ),\n", ")" ] @@ -21222,86 +21221,86 @@ "source": [ "# TODO NEXT UP!!\n", "fig, axs = plt.subplots(\n", - " 2, 1, figsize=(12, 14), gridspec_kw={\"height_ratios\": [1, 5]}, sharex=True\n", + " 2, 1, figsize=(12, 14), gridspec_kw={'height_ratios': [1, 5]}, sharex=True\n", ")\n", "plt.subplots_adjust(wspace=0.1, hspace=0.05)\n", "\n", - "df_test_unique = df_test.dropna(subset=[\"CCS\"]).drop_duplicates(\n", - " subset=\"SMILES\", keep=\"first\"\n", + "df_test_unique = df_test.dropna(subset=['CCS']).drop_duplicates(\n", + " subset='SMILES', keep='first'\n", ")\n", - "df_cas22_unique = df_cas22.dropna(subset=[\"CCS\"]).drop_duplicates(\n", - " subset=\"SMILES\", keep=\"first\"\n", + "df_cas22_unique = df_cas22.dropna(subset=['CCS']).drop_duplicates(\n", + " subset='SMILES', keep='first'\n", ") # Note that with more lenient filters, CCS values might be annotated for CASMI 22\n", "\n", "CCS = pd.concat(\n", " [\n", - " df_test_unique[[\"CCS\", \"CCS_pred\", \"Dataset\"]],\n", - " df_cas[[\"CCS\", \"CCS_pred\", \"Dataset\"]],\n", - " df_cas22_unique[[\"CCS\", \"CCS_pred\", \"Dataset\"]],\n", + " df_test_unique[['CCS', 'CCS_pred', 'Dataset']],\n", + " df_cas[['CCS', 'CCS_pred', 'Dataset']],\n", + " df_cas22_unique[['CCS', 'CCS_pred', 'Dataset']],\n", " ],\n", " ignore_index=True,\n", - ") #\n", + ")\n", "\n", "\n", "# sns.histplot(ax=axs[0], data=df_val_unique, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", "sns.kdeplot(\n", " ax=axs[0],\n", - " data=CCS[CCS[\"Dataset\"] != \"CASMI 22\"],\n", - " x=\"CCS\",\n", + " data=CCS[CCS['Dataset'] != 'CASMI 22'],\n", + " x='CCS',\n", " bw_adjust=0.35,\n", - " color=\"black\",\n", - " multiple=\"stack\",\n", - " hue=\"Dataset\",\n", - " palette=[\"#696969\"] + [\"white\"],\n", + " color='black',\n", + " multiple='stack',\n", + " hue='Dataset',\n", + " palette=['#696969'] + ['white'],\n", " linewidth=1.7,\n", " fill=False,\n", ") # , edgecolor=\"lightgray\") #hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", "sns.kdeplot(\n", " ax=axs[0],\n", - " data=CCS[CCS[\"Dataset\"] != \"CASMI 22\"],\n", - " x=\"CCS\",\n", + " data=CCS[CCS['Dataset'] != 'CASMI 22'],\n", + " x='CCS',\n", " bw_adjust=0.35,\n", - " color=\"black\",\n", - " multiple=\"stack\",\n", - " hue=\"Dataset\",\n", - " palette=[\"gray\"] + [tri_palette[1]],\n", + " color='black',\n", + " 
multiple='stack',\n", + " hue='Dataset',\n", + " palette=['gray'] + [tri_palette[1]],\n", " alpha=0.7,\n", " fill=True,\n", - " edgecolor=\"gray\",\n", + " edgecolor='gray',\n", ") # , edgecolor=\"lightgray\") #hue=\"Precursor_type\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", - "axs[0].spines[\"top\"].set_visible(False)\n", - "axs[0].spines[\"right\"].set_visible(False)\n", + "axs[0].spines['top'].set_visible(False)\n", + "axs[0].spines['right'].set_visible(False)\n", "\n", "\n", "sns.scatterplot(\n", " ax=axs[1],\n", " data=CCS,\n", - " x=\"CCS\",\n", - " y=\"CCS_pred\",\n", - " hue=\"Dataset\",\n", + " x='CCS',\n", + " y='CCS_pred',\n", + " hue='Dataset',\n", " palette=tri_palette,\n", - " style=\"Dataset\",\n", - " markers=[(4, 1, 0), \"v\", \"o\", (4, 0, 45), \"v\", \"D\"],\n", + " style='Dataset',\n", + " markers=[(4, 1, 0), 'v', 'o', (4, 0, 45), 'v', 'D'],\n", " s=35,\n", " linewidth=0.0,\n", ") # , s=50, edgecolor=\"white\")#, linewidth=.0)#, color=\"blue\", edgecolor=\"blue\")#,\n", - "axs[1].set_ylim([df_test_unique[\"CCS\"].min() - 30, df_test_unique[\"CCS\"].max() + 5])\n", - "axs[1].set_xlim([df_test_unique[\"CCS\"].min() - 30, df_test_unique[\"CCS\"].max() + 5])\n", - "axs[1].set_ylabel(\"Predicted CCS\")\n", - "axs[1].set_xlabel(\"Observed CCS\")\n", - "line = [df_test_unique[\"CCS\"].min() - 30, df_test_unique[\"CCS\"].max() + 5]\n", - "sns.lineplot(ax=axs[1], x=line, y=line, color=\"black\")\n", + "axs[1].set_ylim([df_test_unique['CCS'].min() - 30, df_test_unique['CCS'].max() + 5])\n", + "axs[1].set_xlim([df_test_unique['CCS'].min() - 30, df_test_unique['CCS'].max() + 5])\n", + "axs[1].set_ylabel('Predicted CCS')\n", + "axs[1].set_xlabel('Observed CCS')\n", + "line = [df_test_unique['CCS'].min() - 30, df_test_unique['CCS'].max() + 5]\n", + "sns.lineplot(ax=axs[1], x=line, y=line, color='black')\n", "sns.lineplot(\n", - " ax=axs[1], x=line, y=[1.1 * x for x in line], color=\"black\", linestyle=\"--\"\n", + " ax=axs[1], x=line, y=[1.1 * x for x in line], color='black', linestyle='--'\n", ")\n", "sns.lineplot(\n", - " ax=axs[1], x=line, y=[0.9 * x for x in line], color=\"black\", linestyle=\"--\"\n", + " ax=axs[1], x=line, y=[0.9 * x for x in line], color='black', linestyle='--'\n", ")\n", "\n", - "axs[0].tick_params(axis=\"both\", labelsize=16)\n", - "axs[1].tick_params(axis=\"both\", labelsize=16)\n", - "plt.rc(\"axes\", labelsize=20)\n", - "plt.rc(\"legend\", fontsize=20)\n", + "axs[0].tick_params(axis='both', labelsize=16)\n", + "axs[1].tick_params(axis='both', labelsize=16)\n", + "plt.rc('axes', labelsize=20)\n", + "plt.rc('legend', fontsize=20)\n", "\n", "# fig.savefig(f\"{home}/images/paper/ccs.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/ccs.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -21350,77 +21349,77 @@ ], "source": [ "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(\n", - " df_train.dropna(subset=[\"CCS\"])[\"PRECURSORMZ\"],\n", - " df_train.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_train.dropna(subset=['CCS'])['PRECURSORMZ'],\n", + " df_train.dropna(subset=['CCS'])['CCS'],\n", ")\n", - "print(\"TEST SPLIT:\\n\")\n", - "print(\"Pearson Corr Coef:\")\n", + "print('TEST SPLIT:\\n')\n", + "print('Pearson Corr Coef:')\n", "print(\n", - " \"GNN\",\n", + " 'GNN',\n", " np.corrcoef(\n", - " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " 
df_test_unique.dropna(subset=['CCS'])['CCS'],\n", + " df_test_unique.dropna(subset=['CCS'])['CCS_pred'].dropna(),\n", " dtype=float,\n", " )[0, 1],\n", ")\n", "print(\n", - " \"LR \",\n", + " 'LR ',\n", " np.corrcoef(\n", - " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_test_unique.dropna(subset=['CCS'])['CCS'],\n", " intercept\n", - " + slope * df_test_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna(),\n", + " + slope * df_test_unique.dropna(subset=['CCS'])['PrecursorMZ'].dropna(),\n", " dtype=float,\n", " )[0, 1],\n", ")\n", "\n", - "print(\"R2\")\n", + "print('R2')\n", "print(\n", - " \"GNN\",\n", + " 'GNN',\n", " r2_score(\n", - " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " df_test_unique.dropna(subset=['CCS'])['CCS'],\n", + " df_test_unique.dropna(subset=['CCS'])['CCS_pred'].dropna(),\n", " ),\n", ")\n", "print(\n", - " \"LR \",\n", + " 'LR ',\n", " r2_score(\n", - " df_test_unique.dropna(subset=[\"CCS\"])[\"CCS\"],\n", + " df_test_unique.dropna(subset=['CCS'])['CCS'],\n", " intercept\n", - " + slope * df_test_unique.dropna(subset=[\"CCS\"])[\"PrecursorMZ\"].dropna(),\n", + " + slope * df_test_unique.dropna(subset=['CCS'])['PrecursorMZ'].dropna(),\n", " ),\n", ")\n", "\n", - "print(\"---------------\\n\\nCASMI-16:\\n\")\n", - "print(\"Pearson Corr Coef:\")\n", + "print('---------------\\n\\nCASMI-16:\\n')\n", + "print('Pearson Corr Coef:')\n", "print(\n", - " \"GNN\",\n", + " 'GNN',\n", " np.corrcoef(\n", - " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " df_cas.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " df_cas.dropna(subset=['CCS'])['CCS'],\n", + " df_cas.dropna(subset=['CCS'])['CCS_pred'].dropna(),\n", " dtype=float,\n", " )[0, 1],\n", ")\n", "print(\n", - " \"LR \",\n", + " 'LR ',\n", " np.corrcoef(\n", - " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " intercept + slope * df_cas.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna(),\n", + " df_cas.dropna(subset=['CCS'])['CCS'],\n", + " intercept + slope * df_cas.dropna(subset=['CCS'])['PRECURSOR_MZ'].dropna(),\n", " dtype=float,\n", " )[0, 1],\n", ")\n", - "print(\"R2\")\n", + "print('R2')\n", "print(\n", - " \"GNN\",\n", + " 'GNN',\n", " r2_score(\n", - " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " df_cas.dropna(subset=[\"CCS\"])[\"CCS_pred\"].dropna(),\n", + " df_cas.dropna(subset=['CCS'])['CCS'],\n", + " df_cas.dropna(subset=['CCS'])['CCS_pred'].dropna(),\n", " ),\n", ")\n", "print(\n", - " \"LR \",\n", + " 'LR ',\n", " r2_score(\n", - " df_cas.dropna(subset=[\"CCS\"])[\"CCS\"],\n", - " intercept + slope * df_cas.dropna(subset=[\"CCS\"])[\"PRECURSOR_MZ\"].dropna(),\n", + " df_cas.dropna(subset=['CCS'])['CCS'],\n", + " intercept + slope * df_cas.dropna(subset=['CCS'])['PRECURSOR_MZ'].dropna(),\n", " ),\n", ")" ] @@ -21432,9 +21431,9 @@ "outputs": [], "source": [ "# Load coverage into dataframe\n", - "df_test[\"coverage\"] = df_test[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", - "df_cas[\"coverage\"] = df_cas[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", - "df_cas22[\"coverage\"] = df_cas22[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])" + "df_test['coverage'] = df_test['Metabolite'].apply(lambda x: x.match_stats['coverage'])\n", + "df_cas['coverage'] = df_cas['Metabolite'].apply(lambda x: x.match_stats['coverage'])\n", + "df_cas22['coverage'] = df_cas22['Metabolite'].apply(lambda x: x.match_stats['coverage'])" ] }, { @@ -21460,15 +21459,15 @@ 
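The 'tanimoto3' values binned along the x-axis of the similarity figures above come from Metabolite.tanimoto_similarity(..., finger='morgan3'), i.e. Morgan fingerprints of radius 3 scored against every training compound. A standalone sketch with RDKit; only the radius is implied by the name, the 2048-bit fingerprint width is an assumption:

    from rdkit import Chem, DataStructs
    from rdkit.Chem import AllChem

    def morgan3_tanimoto(smiles_a: str, smiles_b: str) -> float:
        # Circular (Morgan) fingerprints of radius 3; bit width assumed.
        fp_a = AllChem.GetMorganFingerprintAsBitVect(
            Chem.MolFromSmiles(smiles_a), 3, nBits=2048)
        fp_b = AllChem.GetMorganFingerprintAsBitVect(
            Chem.MolFromSmiles(smiles_b), 3, nBits=2048)
        return DataStructs.TanimotoSimilarity(fp_a, fp_b)

    # The per-compound score is then the maximum over the training set:
    # max(morgan3_tanimoto(s, t) for t in train_smiles)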
"sns.histplot(\n", " ax=ax,\n", " data=CAT,\n", - " x=\"CE\",\n", - " hue=\"Dataset\",\n", - " multiple=\"stack\",\n", + " x='CE',\n", + " hue='Dataset',\n", + " multiple='stack',\n", " palette=tri_palette,\n", - " stat=\"density\",\n", + " stat='density',\n", " common_norm=False,\n", " kde=False,\n", ")\n", - "plt.xlabel(\"Collision energy\")\n", + "plt.xlabel('Collision energy')\n", "plt.show()" ] }, @@ -21515,7 +21514,7 @@ } ], "source": [ - "sns.palplot(sns.color_palette(\"colorblind\"))\n", + "sns.palplot(sns.color_palette('colorblind'))\n", "plt.show()" ] }, @@ -21546,25 +21545,25 @@ "source": [ "set_light_theme()\n", "scores = [\n", - " \"spectral_cosine\",\n", - " \"spectral_sqrt_cosine\",\n", - " \"spectral_sqrt_cosine_wo_prec\",\n", - " \"spectral_refl_cosine\",\n", - " \"steins_cosine\",\n", + " 'spectral_cosine',\n", + " 'spectral_sqrt_cosine',\n", + " 'spectral_sqrt_cosine_wo_prec',\n", + " 'spectral_refl_cosine',\n", + " 'steins_cosine',\n", "]\n", - "biases = [s.replace(\"cosine\", \"bias\") for s in scores]\n", + "biases = [s.replace('cosine', 'bias') for s in scores]\n", "# labels = [\"cosine\", r\"cosine $\\sqrt{int}$ \", r\"$\\sqrt{i}$ cosine w/o precursor\", r\"$\\sqrt{i}$ reflection cosine\", r\"$\\sqrt{i}$ mass-weighted cosine\"]\n", "\n", "# labels = [\"cosine\", \"sqrt cosine\", \"sqrt cosine w/o precursor\", \"sqrt reflection cosine\", \"sqrt mass-weighted cosine\"]\n", "labels = [\n", - " \"Standard\\ncosine\",\n", - " \"Square root\\nintensities\",\n", - " \"Square root\\nintensities\\nw/o precursor\",\n", - " \"Square root\\nintensities\\nreflection score\",\n", - " \"Square root\\nintensities\\nscaled by m/z\",\n", + " 'Standard\\ncosine',\n", + " 'Square root\\nintensities',\n", + " 'Square root\\nintensities\\nw/o precursor',\n", + " 'Square root\\nintensities\\nreflection score',\n", + " 'Square root\\nintensities\\nscaled by m/z',\n", "]\n", - "S = df_test.melt(id_vars=\"Name\", var_name=\"Score\", value_vars=scores)\n", - "B = df_test.melt(id_vars=\"Name\", var_name=\"Score\", value_vars=biases)\n", + "S = df_test.melt(id_vars='Name', var_name='Score', value_vars=scores)\n", + "B = df_test.melt(id_vars='Name', var_name='Score', value_vars=biases)\n", "\n", "\n", "fig, axs = plt.subplots(2, 1, figsize=(10, 10), sharex=False)\n", @@ -21572,62 +21571,62 @@ "\n", "\n", "highlight_2 = [\n", - " sns.color_palette(\"colorblind\")[7],\n", - " sns.color_palette(\"colorblind\")[9],\n", - " sns.color_palette(\"colorblind\")[7],\n", - " sns.color_palette(\"colorblind\")[7],\n", - " sns.color_palette(\"colorblind\")[7],\n", + " sns.color_palette('colorblind')[7],\n", + " sns.color_palette('colorblind')[9],\n", + " sns.color_palette('colorblind')[7],\n", + " sns.color_palette('colorblind')[7],\n", + " sns.color_palette('colorblind')[7],\n", "]\n", "sns.boxplot(\n", " ax=axs[0],\n", " data=S,\n", - " y=\"value\",\n", - " x=\"Score\",\n", + " y='value',\n", + " x='Score',\n", " order=scores,\n", - " hue=\"Score\",\n", + " hue='Score',\n", " palette=highlight_2,\n", " showfliers=False,\n", ")\n", "sns.boxplot(\n", " ax=axs[1],\n", " data=B,\n", - " y=\"value\",\n", - " x=\"Score\",\n", + " y='value',\n", + " x='Score',\n", " order=biases,\n", - " hue=\"Score\",\n", + " hue='Score',\n", " palette=highlight_2,\n", " showfliers=False,\n", ")\n", - "axs[0].set_xticklabels(\"\")\n", - "axs[0].set_xlabel(\"\")\n", - "axs[0].set_ylabel(\"Similarity\", fontsize=14)\n", - "axs[1].set_xlabel(\"\")\n", + "axs[0].set_xticklabels('')\n", + "axs[0].set_xlabel('')\n", + 
"axs[0].set_ylabel('Similarity', fontsize=14)\n", + "axs[1].set_xlabel('')\n", "axs[1].set_xticklabels(labels)\n", - "axs[1].set_ylabel(\"Bias\", fontsize=14)\n", + "axs[1].set_ylabel('Bias', fontsize=14)\n", "# axs[1].set_ylim([0, 1])\n", "# axs[1].axhline(y=B.groupby('Score')['value'].median().min(), xmin=0, xmax=10, color=\"red\", linestyle=\"--\")\n", "\n", - "plt.rc(\"axes\", labelsize=14)\n", - "plt.rc(\"legend\", fontsize=14)\n", - "axs[0].tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", - "axs[1].tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", + "plt.rc('axes', labelsize=14)\n", + "plt.rc('legend', fontsize=14)\n", + "axs[0].tick_params(axis='both', which='major', labelsize=13)\n", + "axs[1].tick_params(axis='both', which='major', labelsize=13)\n", "axs[0].text(\n", " 0.02,\n", " 0.025,\n", - " f\"n={df_test.shape[0]} for all categories\",\n", + " f'n={df_test.shape[0]} for all categories',\n", " transform=axs[0].transAxes,\n", " fontsize=13,\n", - " va=\"bottom\",\n", - " ha=\"left\",\n", + " va='bottom',\n", + " ha='left',\n", ")\n", "axs[1].text(\n", " 0.02,\n", " 0.025,\n", - " f\"n={df_test.shape[0]} for all categories\",\n", + " f'n={df_test.shape[0]} for all categories',\n", " transform=axs[1].transAxes,\n", " fontsize=13,\n", - " va=\"bottom\",\n", - " ha=\"left\",\n", + " va='bottom',\n", + " ha='left',\n", ")\n", "\n", "\n", @@ -21677,7 +21676,7 @@ } ], "source": [ - "df_cas[\"spectral_sqrt_cosine\"]" + "df_cas['spectral_sqrt_cosine']" ] }, { @@ -21717,17 +21716,17 @@ "set_light_theme()\n", "fig, ax = plt.subplots(1, 1, figsize=(12, 6))\n", "score_columns = [\n", - " \"spectral_sqrt_cosine_20\",\n", - " \"spectral_sqrt_cosine_35\",\n", - " \"spectral_sqrt_cosine_50\",\n", - " \"spectral_sqrt_cosine\",\n", + " 'spectral_sqrt_cosine_20',\n", + " 'spectral_sqrt_cosine_35',\n", + " 'spectral_sqrt_cosine_50',\n", + " 'spectral_sqrt_cosine',\n", "]\n", "\n", "df_filtered = df_cas[score_columns]\n", "\n", - "df_melted = df_filtered.melt(var_name=\"Score_Type\")\n", - "custom_labels = [\"20\", \"35\", \"50\", \"20/35/50\"]\n", - "magma = sns.color_palette(\"magma_r\", 4)\n", + "df_melted = df_filtered.melt(var_name='Score_Type')\n", + "custom_labels = ['20', '35', '50', '20/35/50']\n", + "magma = sns.color_palette('magma_r', 4)\n", "adjusted_last_color = sns.utils.set_hls_values(\n", " magma[1], l=min(1.0, magma[-1][2] * 1.2)\n", ") # Increase lightness by 20%\n", @@ -21738,40 +21737,40 @@ "ax = sns.boxplot(\n", " ax=ax,\n", " data=df_melted,\n", - " x=\"Score_Type\",\n", - " y=\"value\",\n", - " hue=\"Score_Type\",\n", + " x='Score_Type',\n", + " y='value',\n", + " hue='Score_Type',\n", " dodge=False,\n", " showfliers=False,\n", " palette=magma[:3] + [magma[1]],\n", " linewidth=2,\n", ")\n", "bars = ax.patches\n", - "bars[-1].set_hatch(\".\")\n", + "bars[-1].set_hatch('.')\n", "# bars[-1].set_hatch_linewidth(2)\n", "\n", "ax.axhline(\n", - " df_filtered[\"spectral_sqrt_cosine\"].median(),\n", - " color=\"black\",\n", + " df_filtered['spectral_sqrt_cosine'].median(),\n", + " color='black',\n", " linewidth=2,\n", - " linestyle=\"dotted\",\n", - " label=\"Median of spectral_sqrt_cosine\",\n", + " linestyle='dotted',\n", + " label='Median of spectral_sqrt_cosine',\n", ")\n", - "ax.set_xlabel(\"NCE\")\n", - "ax.set_ylabel(\"Cosine similarity\")\n", + "ax.set_xlabel('NCE')\n", + "ax.set_ylabel('Cosine similarity')\n", "ax.set_xticklabels(custom_labels)\n", "\n", - "plt.rc(\"axes\", labelsize=14)\n", - "plt.rc(\"legend\", fontsize=14)\n", - 
"ax.tick_params(axis=\"both\", which=\"major\", labelsize=13) # For major ticks\n", + "plt.rc('axes', labelsize=14)\n", + "plt.rc('legend', fontsize=14)\n", + "ax.tick_params(axis='both', which='major', labelsize=13) # For major ticks\n", "ax.text(\n", " 0.02,\n", " 0.02,\n", - " f\"n={df_filtered.shape[0]} for all data points\",\n", + " f'n={df_filtered.shape[0]} for all data points',\n", " transform=ax.transAxes,\n", " fontsize=13,\n", - " va=\"bottom\",\n", - " ha=\"left\",\n", + " va='bottom',\n", + " ha='left',\n", ")\n", "\n", "\n", @@ -21805,7 +21804,7 @@ } ], "source": [ - "df_melted.groupby(\"Score_Type\")[\"value\"].median()" + "df_melted.groupby('Score_Type')['value'].median()" ] }, { @@ -21896,24 +21895,24 @@ " # d[\"mz\"] = d[\"mz\"][high_mz_idx.tolist()]\n", " # d[\"intensity\"] = d[\"intensity\"][high_mz_idx.tolist()]\n", "\n", - " order = np.argsort(peaks[\"intensity\"])[::-1]\n", - " value = fraction * np.sum(peaks[\"intensity\"])\n", + " order = np.argsort(peaks['intensity'])[::-1]\n", + " value = fraction * np.sum(peaks['intensity'])\n", " num_of_relevant_peaks = min(\n", - " max_peaks - 1, np.argmax(np.cumsum(np.array(peaks[\"intensity\"])[order]) > value)\n", + " max_peaks - 1, np.argmax(np.cumsum(np.array(peaks['intensity'])[order]) > value)\n", " )\n", " indices = order[: num_of_relevant_peaks + 1]\n", "\n", " d = {\n", - " \"mz\": np.array(peaks[\"mz\"])[indices].tolist(),\n", - " \"intensity\": np.array(peaks[\"intensity\"])[indices].tolist(),\n", + " 'mz': np.array(peaks['mz'])[indices].tolist(),\n", + " 'intensity': np.array(peaks['intensity'])[indices].tolist(),\n", " }\n", "\n", " return d\n", " # return d\n", "\n", "\n", - "print(df_cas[\"peaks\"].iloc[0])\n", - "p = filter_peaks(df_cas[\"peaks\"].iloc[0], max_peaks=10)\n", + "print(df_cas['peaks'].iloc[0])\n", + "p = filter_peaks(df_cas['peaks'].iloc[0], max_peaks=10)\n", "p" ] }, @@ -21923,29 +21922,29 @@ "metadata": {}, "outputs": [], "source": [ - "df_cas22[\"num_peaks\"] = df_cas22[\"peaks\"].apply(lambda x: len(x[\"mz\"]))\n", - "df_cas[\"num_peaks\"] = df_cas[\"peaks\"].apply(lambda x: len(x[\"mz\"]))\n", - "df_test[\"num_peaks\"] = df_test[\"peaks\"].apply(lambda x: len(x[\"mz\"]))\n", - "df_msnlib_test[\"num_peaks\"] = df_msnlib_test[\"peaks\"].apply(lambda x: len(x[\"mz\"]))\n", - "df_cas22[\"filtered_peaks\"] = df_cas22[\"peaks\"].apply(\n", + "df_cas22['num_peaks'] = df_cas22['peaks'].apply(lambda x: len(x['mz']))\n", + "df_cas['num_peaks'] = df_cas['peaks'].apply(lambda x: len(x['mz']))\n", + "df_test['num_peaks'] = df_test['peaks'].apply(lambda x: len(x['mz']))\n", + "df_msnlib_test['num_peaks'] = df_msnlib_test['peaks'].apply(lambda x: len(x['mz']))\n", + "df_cas22['filtered_peaks'] = df_cas22['peaks'].apply(\n", " lambda x: filter_peaks(x, max_peaks=20, fraction=0.8)\n", ")\n", - "df_cas[\"filtered_peaks\"] = df_cas[\"peaks\"].apply(\n", + "df_cas['filtered_peaks'] = df_cas['peaks'].apply(\n", " lambda x: filter_peaks(x, max_peaks=20, fraction=0.8)\n", ")\n", - "df_test[\"filtered_peaks\"] = df_test[\"peaks\"].apply(\n", + "df_test['filtered_peaks'] = df_test['peaks'].apply(\n", " lambda x: filter_peaks(x, max_peaks=20, fraction=0.8)\n", ")\n", - "df_msnlib_test[\"filtered_peaks\"] = df_msnlib_test[\"peaks\"].apply(\n", + "df_msnlib_test['filtered_peaks'] = df_msnlib_test['peaks'].apply(\n", " lambda x: filter_peaks(x, max_peaks=20, fraction=0.8)\n", ")\n", - "df_cas22[\"num_filtered_peaks\"] = df_cas22[\"filtered_peaks\"].apply(\n", - " lambda x: len(x[\"mz\"])\n", + 
"df_cas22['num_filtered_peaks'] = df_cas22['filtered_peaks'].apply(\n", + " lambda x: len(x['mz'])\n", ")\n", - "df_cas[\"num_filtered_peaks\"] = df_cas[\"filtered_peaks\"].apply(lambda x: len(x[\"mz\"]))\n", - "df_test[\"num_filtered_peaks\"] = df_test[\"filtered_peaks\"].apply(lambda x: len(x[\"mz\"]))\n", - "df_msnlib_test[\"num_filtered_peaks\"] = df_msnlib_test[\"filtered_peaks\"].apply(\n", - " lambda x: len(x[\"mz\"])\n", + "df_cas['num_filtered_peaks'] = df_cas['filtered_peaks'].apply(lambda x: len(x['mz']))\n", + "df_test['num_filtered_peaks'] = df_test['filtered_peaks'].apply(lambda x: len(x['mz']))\n", + "df_msnlib_test['num_filtered_peaks'] = df_msnlib_test['filtered_peaks'].apply(\n", + " lambda x: len(x['mz'])\n", ")\n", "CAT = pd.concat([df_test, df_msnlib_test, df_cas, df_cas22])" ] @@ -21980,12 +21979,12 @@ ], "source": [ "reset_matplotlib()\n", - "fig, axs = plt.subplots(4, 4, figsize=(12, 8), sharex=\"col\", sharey=\"col\")\n", + "fig, axs = plt.subplots(4, 4, figsize=(12, 8), sharex='col', sharey='col')\n", "plt.subplots_adjust(\n", " hspace=0.1, wspace=0.20, right=0.95\n", ") # , top=0.94, bottom=0.12, right=0.97, left=0)\n", "\n", - "dataset_names = [\"Test split\", \"MSnLib\", \"CASMI 16\", \"CASMI 22\"]\n", + "dataset_names = ['Test split', 'MSnLib', 'CASMI 16', 'CASMI 22']\n", "\n", "\n", "# Loop through each row and set row labels\n", @@ -21994,71 +21993,71 @@ " dataset_name,\n", " rotation=90,\n", " labelpad=10,\n", - " ha=\"center\",\n", - " va=\"center\",\n", + " ha='center',\n", + " va='center',\n", " fontsize=12,\n", - " fontweight=\"bold\",\n", + " fontweight='bold',\n", " )\n", " # Column 1\n", " sns.histplot(\n", " ax=axs[i, 0],\n", - " data=CAT[CAT[\"Dataset\"] == dataset_name],\n", - " x=\"CE\",\n", + " data=CAT[CAT['Dataset'] == dataset_name],\n", + " x='CE',\n", " binwidth=5,\n", - " stat=\"percent\",\n", + " stat='percent',\n", " )\n", " axs[i, 0].set_xlim(0, 100)\n", " axs[i, 0].set_xticks([0, 25, 50, 75, 100])\n", - " axs[i, 0].set_xlabel(\"Collision energy (eV)\", fontsize=12)\n", - " axs[i, 0].tick_params(axis=\"both\", which=\"major\", labelsize=11)\n", + " axs[i, 0].set_xlabel('Collision energy (eV)', fontsize=12)\n", + " axs[i, 0].tick_params(axis='both', which='major', labelsize=11)\n", "\n", " # Column 2\n", " sns.histplot(\n", " ax=axs[i, 1],\n", - " data=CAT[CAT[\"Dataset\"] == dataset_name],\n", - " x=\"num_peaks\",\n", + " data=CAT[CAT['Dataset'] == dataset_name],\n", + " x='num_peaks',\n", " binwidth=10,\n", - " stat=\"percent\",\n", + " stat='percent',\n", " )\n", " axs[i, 1].set_xticks(list(range(0, 260, 50)))\n", " axs[i, 1].set_xlim(0, 250)\n", - " axs[i, 1].set_xlabel(\"Number of peaks\", fontsize=12)\n", - " axs[i, 1].set_ylabel(\"\")\n", - " axs[i, 1].tick_params(axis=\"both\", which=\"major\", labelsize=11)\n", + " axs[i, 1].set_xlabel('Number of peaks', fontsize=12)\n", + " axs[i, 1].set_ylabel('')\n", + " axs[i, 1].tick_params(axis='both', which='major', labelsize=11)\n", "\n", " # Column 3\n", " mz = [\n", " item\n", - " for sublist in CAT[CAT[\"Dataset\"] == dataset_name][\"peaks\"].apply(\n", - " lambda x: x[\"mz\"]\n", + " for sublist in CAT[CAT['Dataset'] == dataset_name]['peaks'].apply(\n", + " lambda x: x['mz']\n", " )\n", " for item in sublist\n", " ]\n", - " sns.histplot(mz, ax=axs[i, 2], binwidth=10, stat=\"percent\")\n", + " sns.histplot(mz, ax=axs[i, 2], binwidth=10, stat='percent')\n", " axs[i, 2].set_xticks(list(range(0, 550, 100)))\n", " axs[i, 2].set_xlim(0, 500)\n", - " axs[i, 2].set_xlabel(\"Peak m/z\", 
fontsize=12)\n", - " axs[i, 2].set_ylabel(\"\")\n", - " axs[i, 2].axvline(x=125, color=\"black\", linestyle=\"--\", linewidth=1)\n", - " axs[i, 2].tick_params(axis=\"both\", which=\"major\", labelsize=11)\n", + " axs[i, 2].set_xlabel('Peak m/z', fontsize=12)\n", + " axs[i, 2].set_ylabel('')\n", + " axs[i, 2].axvline(x=125, color='black', linestyle='--', linewidth=1)\n", + " axs[i, 2].tick_params(axis='both', which='major', labelsize=11)\n", "\n", " # Column 4\n", " sns.histplot(\n", " ax=axs[i, 3],\n", - " data=CAT[CAT[\"Dataset\"] == dataset_name],\n", - " x=\"num_filtered_peaks\",\n", - " hue=\"Precursor_type\",\n", + " data=CAT[CAT['Dataset'] == dataset_name],\n", + " x='num_filtered_peaks',\n", + " hue='Precursor_type',\n", " binwidth=1,\n", - " stat=\"percent\",\n", - " multiple=\"dodge\",\n", - " hue_order=[\"[M+H]+\", \"[M-H]-\"],\n", + " stat='percent',\n", + " multiple='dodge',\n", + " hue_order=['[M+H]+', '[M-H]-'],\n", " )\n", " axs[i, 3].set_xlim(1, 21)\n", " if i > 0:\n", " axs[i, 3].legend().set_visible(False)\n", - " axs[i, 3].set_xlabel(\"Number of peaks\\nexplaining 80% of intensity\", fontsize=12)\n", - " axs[i, 3].set_ylabel(\"\")\n", - " axs[i, 3].tick_params(axis=\"both\", which=\"major\", labelsize=11)\n", + " axs[i, 3].set_xlabel('Number of peaks\\nexplaining 80% of intensity', fontsize=12)\n", + " axs[i, 3].set_ylabel('')\n", + " axs[i, 3].tick_params(axis='both', which='major', labelsize=11)\n", "\n", "\n", "# fig.savefig(f\"{home}/images/paper/data_hists.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -22085,8 +22084,8 @@ "source": [ "from fiora.MS.spectral_scores import spectral_cosine\n", "\n", - "df_cas22[\"filtered_cosine\"] = df_cas22.apply(\n", - " lambda x: spectral_cosine(x[\"sim_peaks\"], x[\"filtered_peaks\"], transform=np.sqrt),\n", + "df_cas22['filtered_cosine'] = df_cas22.apply(\n", + " lambda x: spectral_cosine(x['sim_peaks'], x['filtered_peaks'], transform=np.sqrt),\n", " axis=1,\n", ")" ] @@ -22108,7 +22107,7 @@ } ], "source": [ - "np.mean(df_cas22[\"num_peak_matches_filtered\"] / df_cas22[\"num_peaks_filtered\"])" + "np.mean(df_cas22['num_peak_matches_filtered'] / df_cas22['num_peaks_filtered'])" ] }, { @@ -22128,7 +22127,7 @@ } ], "source": [ - "np.median(df_cas22[\"coverage\"])" + "np.median(df_cas22['coverage'])" ] }, { @@ -22148,7 +22147,7 @@ } ], "source": [ - "np.sum(df_cas[\"spectral_sqrt_cosine\"] >= 0.70) / df_cas.shape[0]" + "np.sum(df_cas['spectral_sqrt_cosine'] >= 0.70) / df_cas.shape[0]" ] }, { @@ -22168,7 +22167,7 @@ } ], "source": [ - "np.mean(df_cas22[\"num_peak_matches_filtered\"] / df_cas22[\"num_peaks_filtered\"])" + "np.mean(df_cas22['num_peak_matches_filtered'] / df_cas22['num_peaks_filtered'])" ] }, { @@ -22177,7 +22176,7 @@ "metadata": {}, "outputs": [], "source": [ - "df_test[\"CE_10\"] = np.ceil(df_test[\"CE\"] / 10) * 10\n", + "df_test['CE_10'] = np.ceil(df_test['CE'] / 10) * 10\n", "set_light_theme()" ] }, @@ -22225,7 +22224,7 @@ " 2,\n", " 2,\n", " figsize=(16, 12),\n", - " gridspec_kw={\"width_ratios\": [1, 1]},\n", + " gridspec_kw={'width_ratios': [1, 1]},\n", " sharex=True,\n", " sharey=True,\n", ")\n", @@ -22233,49 +22232,49 @@ " hspace=0.12, wspace=0.05\n", ") # top=0.94, bottom=0.12, right=0.97, left=0.08)\n", "\n", - "magma10 = sns.color_palette(\"magma_r\", 18)[:10]\n", + "magma10 = sns.color_palette('magma_r', 18)[:10]\n", "sns.boxplot(\n", " ax=axs[0, 0],\n", " data=df_test,\n", - " y=\"spectral_sqrt_cosine\",\n", - " x=\"CE_10\",\n", + " y='spectral_sqrt_cosine',\n", + " 
x='CE_10',\n", " showfliers=False,\n", " palette=magma10,\n", " linewidth=1.5,\n", ") # color=\"white\", linewidth=2, linecolor=\"black\")\n", - "axs[0, 0].set_title(\"Fiora\", fontweight=\"bold\")\n", - "axs[0, 0].set_ylabel(\"Cosine similarity\")\n", + "axs[0, 0].set_title('Fiora', fontweight='bold')\n", + "axs[0, 0].set_ylabel('Cosine similarity')\n", "\n", "sns.boxplot(\n", " ax=axs[0, 1],\n", " data=df_test,\n", - " y=\"ice_sqrt_cosine\",\n", - " x=\"CE_10\",\n", + " y='ice_sqrt_cosine',\n", + " x='CE_10',\n", " showfliers=False,\n", " palette=magma10,\n", " linewidth=1.5,\n", ")\n", - "axs[0, 1].set_title(\"ICEBERG\", fontweight=\"bold\")\n", - "axs[0, 1].set_xlabel(\"Collision energy\")\n", - "axs[0, 1].xaxis.set_label_position(\"bottom\")\n", + "axs[0, 1].set_title('ICEBERG', fontweight='bold')\n", + "axs[0, 1].set_xlabel('Collision energy')\n", + "axs[0, 1].xaxis.set_label_position('bottom')\n", "axs[0, 1].set_xticklabels(list(range(10, 110, 10)))\n", "axs[0, 1].tick_params(\n", - " axis=\"x\", which=\"both\", bottom=True, top=False, labelbottom=True\n", + " axis='x', which='both', bottom=True, top=False, labelbottom=True\n", ") # Ensure tick labels visibility\n", "axs[0, 1].xaxis.label.set_visible(True)\n", "\n", "sns.boxplot(\n", " ax=axs[1, 0],\n", " data=df_test,\n", - " y=\"cfm_sqrt_cosine\",\n", - " x=\"CE_10\",\n", + " y='cfm_sqrt_cosine',\n", + " x='CE_10',\n", " showfliers=False,\n", " palette=magma10,\n", " linewidth=1.5,\n", ")\n", - "axs[1, 0].set_title(\"CFM-ID\", fontweight=\"bold\")\n", - "axs[1, 0].set_xlabel(\"Collision energy\")\n", - "axs[1, 0].set_ylabel(\"Cosine similarity\")\n", + "axs[1, 0].set_title('CFM-ID', fontweight='bold')\n", + "axs[1, 0].set_xlabel('Collision energy')\n", + "axs[1, 0].set_ylabel('Cosine similarity')\n", "axs[1, 1].remove()\n", "\n", "for i in range(2):\n", @@ -22283,32 +22282,32 @@ " if axs[i, j].has_data(): # Check if the subplot has data\n", " for tick, label in zip(axs[i, j].get_xticks(), axs[i, j].get_xticklabels()):\n", " count = len(\n", - " df_test[df_test[\"CE_10\"] == int(label.get_text())]\n", + " df_test[df_test['CE_10'] == int(label.get_text())]\n", " ) # Count data points per category\n", " axs[i, j].text(\n", " tick,\n", " axs[i, j].get_ylim()[0] + 0.04,\n", - " f\"n={count}\",\n", - " ha=\"center\",\n", - " va=\"top\",\n", + " f'n={count}',\n", + " ha='center',\n", + " va='top',\n", " fontsize=10.25,\n", " )\n", " if i == 0 and j == 1:\n", " axs[0, 0].text(\n", " tick,\n", " axs[0, 0].get_ylim()[0] + 0.04,\n", - " f\"n={count}\",\n", - " ha=\"center\",\n", - " va=\"top\",\n", + " f'n={count}',\n", + " ha='center',\n", + " va='top',\n", " fontsize=10.25,\n", " )\n", "\n", - "plt.rc(\"axes\", labelsize=14)\n", - "plt.rc(\"legend\", fontsize=14)\n", + "plt.rc('axes', labelsize=14)\n", + "plt.rc('legend', fontsize=14)\n", "plt.ylim([-0.05, 1.05])\n", - "axs[0, 0].tick_params(axis=\"both\", which=\"major\", labelsize=13) # For major ticks\n", - "axs[0, 1].tick_params(axis=\"both\", which=\"major\", labelsize=13) # For major ticks\n", - "axs[1, 0].tick_params(axis=\"both\", which=\"major\", labelsize=13) # For major ticks\n", + "axs[0, 0].tick_params(axis='both', which='major', labelsize=13) # For major ticks\n", + "axs[0, 1].tick_params(axis='both', which='major', labelsize=13) # For major ticks\n", + "axs[1, 0].tick_params(axis='both', which='major', labelsize=13) # For major ticks\n", "\n", "# fig.savefig(f\"{home}/images/paper/cosine_ce_wo_prec.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# 
fig.savefig(f\"{home}/images/paper/cosine_ce.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -22360,27 +22359,27 @@ "\n", "sns.boxplot(\n", " data=df_test,\n", - " y=\"coverage\",\n", - " x=\"CE_10\",\n", + " y='coverage',\n", + " x='CE_10',\n", " showfliers=False,\n", " palette=magma10,\n", " linewidth=1.5,\n", ") # color=\"white\", linewidth=2, linecolor=\"black\")\n", "ax.set_xticklabels(list(range(10, 110, 10)))\n", - "ax.set_xlabel(\"Collision energy\")\n", - "ax.set_ylabel(\"Peak intensity coverage\")\n", + "ax.set_xlabel('Collision energy')\n", + "ax.set_ylabel('Peak intensity coverage')\n", "\n", "# Add 'n=XXX' annotations\n", "for tick, label in zip(ax.get_xticks(), ax.get_xticklabels()):\n", - " count = len(df_test[df_test[\"CE_10\"] == int(label.get_text())])\n", + " count = len(df_test[df_test['CE_10'] == int(label.get_text())])\n", " ax.text(\n", - " tick, ax.get_ylim()[0] + 0.01, f\"n={count}\", ha=\"center\", va=\"top\", fontsize=13\n", + " tick, ax.get_ylim()[0] + 0.01, f'n={count}', ha='center', va='top', fontsize=13\n", " )\n", "\n", "plt.ylim([0.45, 1.03])\n", - "plt.rc(\"axes\", labelsize=14)\n", - "plt.rc(\"legend\", fontsize=14)\n", - "ax.tick_params(axis=\"both\", which=\"major\", labelsize=13) # For major ticks\n", + "plt.rc('axes', labelsize=14)\n", + "plt.rc('legend', fontsize=14)\n", + "ax.tick_params(axis='both', which='major', labelsize=13) # For major ticks\n", "# plt.subplots_adjust(bottom=0.2)\n", "\n", "# fig.savefig(f\"{home}/images/paper/coverage_ce.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -22408,18 +22407,18 @@ } ], "source": [ - "magma = sns.color_palette(\"magma_r\", 4)\n", + "magma = sns.color_palette('magma_r', 4)\n", "fig, axs = plt.subplots(\n", - " 1, 3, figsize=(18, 6), gridspec_kw={\"width_ratios\": [1, 1, 1]}, sharey=True\n", + " 1, 3, figsize=(18, 6), gridspec_kw={'width_ratios': [1, 1, 1]}, sharey=True\n", ")\n", "plt.subplots_adjust(wspace=0.05)\n", "\n", "sns.boxplot(\n", " ax=axs[0],\n", " data=df_cas22,\n", - " y=\"spectral_sqrt_cosine\",\n", - " x=\"NCE\",\n", - " hue=\"NCE\",\n", + " y='spectral_sqrt_cosine',\n", + " x='NCE',\n", + " hue='NCE',\n", " palette=magma[:3],\n", " showfliers=False,\n", ")\n", @@ -22428,47 +22427,47 @@ "sns.boxplot(\n", " ax=axs[1],\n", " data=df_cas22,\n", - " y=\"cfm_sqrt_cosine\",\n", - " x=\"NCE\",\n", - " hue=\"NCE\",\n", + " y='cfm_sqrt_cosine',\n", + " x='NCE',\n", + " hue='NCE',\n", " palette=magma[:3],\n", " showfliers=False,\n", ")\n", "sns.boxplot(\n", " ax=axs[2],\n", " data=df_cas22,\n", - " y=\"ice_sqrt_cosine\",\n", - " x=\"NCE\",\n", - " hue=\"NCE\",\n", + " y='ice_sqrt_cosine',\n", + " x='NCE',\n", + " hue='NCE',\n", " palette=magma[:3],\n", " showfliers=False,\n", ")\n", - "axs[0].set_title(\"Fiora\", fontweight=\"bold\", fontsize=16)\n", - "axs[1].set_title(\"CFM-ID\", fontweight=\"bold\", fontsize=16)\n", - "axs[2].set_title(\"ICEBERG\", fontweight=\"bold\", fontsize=16)\n", - "axs[0].set_ylabel(\"Cosine similarity\")\n", + "axs[0].set_title('Fiora', fontweight='bold', fontsize=16)\n", + "axs[1].set_title('CFM-ID', fontweight='bold', fontsize=16)\n", + "axs[2].set_title('ICEBERG', fontweight='bold', fontsize=16)\n", + "axs[0].set_ylabel('Cosine similarity')\n", "for ax in axs:\n", " ax.get_legend().remove()\n", "\n", - "plt.rc(\"axes\", labelsize=14)\n", - "plt.rc(\"legend\", fontsize=14)\n", + "plt.rc('axes', labelsize=14)\n", + "plt.rc('legend', fontsize=14)\n", "plt.ylim(-0.08, 1.04)\n", - "axs[0].tick_params(axis=\"both\", 
which=\"major\", labelsize=13)\n", - "axs[1].tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", - "axs[2].tick_params(axis=\"both\", which=\"major\", labelsize=13)\n", + "axs[0].tick_params(axis='both', which='major', labelsize=13)\n", + "axs[1].tick_params(axis='both', which='major', labelsize=13)\n", + "axs[2].tick_params(axis='both', which='major', labelsize=13)\n", "\n", "for ax in axs:\n", " for tick, label in zip(ax.get_xticks(), ax.get_xticklabels()):\n", " category = float(label.get_text()) # Extract the label text\n", " count = len(\n", - " df_cas22[df_cas22[\"NCE\"] == category]\n", + " df_cas22[df_cas22['NCE'] == category]\n", " ) # Count samples for the category\n", " ax.text(\n", " tick,\n", " ax.get_ylim()[0] + 0.05,\n", - " f\"n={count}\",\n", - " ha=\"center\",\n", - " va=\"top\",\n", + " f'n={count}',\n", + " ha='center',\n", + " va='top',\n", " fontsize=13,\n", " )\n", "\n", @@ -22516,48 +22515,48 @@ "\n", "\n", "fig, axs = plt.subplots(\n", - " 2, 1, figsize=(12, 14), sharex=True, gridspec_kw={\"height_ratios\": [1, 5]}\n", + " 2, 1, figsize=(12, 14), sharex=True, gridspec_kw={'height_ratios': [1, 5]}\n", ")\n", "plt.subplots_adjust(hspace=0.025) # right=0.975, left=0.11)\n", "# sns.histplot(ax=axs[0], data=df_cas22, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", - "plt.rc(\"legend\", loc=\"upper center\")\n", + "plt.rc('legend', loc='upper center')\n", "sns.kdeplot(\n", " ax=axs[0],\n", - " data=CAT[CAT[\"Dataset\"].isin([\"CASMI 16\", \"CASMI 22\"])],\n", - " x=\"coverage\",\n", + " data=CAT[CAT['Dataset'].isin(['CASMI 16', 'CASMI 22'])],\n", + " x='coverage',\n", " bw_adjust=0.2,\n", - " color=\"black\",\n", + " color='black',\n", " fill=True,\n", - " multiple=\"layer\",\n", - " hue=\"Dataset\",\n", + " multiple='layer',\n", + " hue='Dataset',\n", " common_norm=False,\n", " palette=tri_palette[1:],\n", ") # hue=\"Precursor_type\",\n", - "axs[0].spines[\"top\"].set_visible(False)\n", - "axs[0].spines[\"right\"].set_visible(False)\n", + "axs[0].spines['top'].set_visible(False)\n", + "axs[0].spines['right'].set_visible(False)\n", "# axs[0].legend(loc='upper center')\n", "plt.xlim([-0.05, 1.05])\n", "axs[1].set_ylim([-0.05, 1.05])\n", "# Generating x values\n", "x = np.linspace(0, 1, 200)\n", "y = np.sqrt(x)\n", - "sns.lineplot(x=x, y=y, color=\"black\", linestyle=\"dotted\")\n", + "sns.lineplot(x=x, y=y, color='black', linestyle='dotted')\n", "\n", "# axs[0].set_title(\"Impact of coverage on cosine scores\")\n", "sns.scatterplot(\n", " ax=axs[1],\n", - " data=CAT[CAT[\"Dataset\"].isin([\"CASMI 16\", \"CASMI 22\"])],\n", - " x=\"coverage\",\n", - " y=\"spectral_sqrt_cosine\",\n", - " hue=\"Dataset\",\n", - " style=\"Dataset\",\n", - " markers=[(4, 1, 0), \".\"],\n", + " data=CAT[CAT['Dataset'].isin(['CASMI 16', 'CASMI 22'])],\n", + " x='coverage',\n", + " y='spectral_sqrt_cosine',\n", + " hue='Dataset',\n", + " style='Dataset',\n", + " markers=[(4, 1, 0), '.'],\n", " s=100,\n", " palette=tri_palette[1:],\n", ") # , hue_norm=(0, 1), palette=bluepink_grad)\n", - "axs[1].set_ylabel(\"Cosine similarity\")\n", - "axs[1].set_xlabel(\"Peak intensity coverage\")\n", - "axs[1].legend(title=\"Dataset\", loc=\"upper left\")\n", + "axs[1].set_ylabel('Cosine similarity')\n", + "axs[1].set_xlabel('Peak intensity coverage')\n", + "axs[1].legend(title='Dataset', loc='upper left')\n", "# fig.savefig(f\"{home}/images/paper/coverage_top_only.svg\", 
format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/coverage_top_only.pdf\", format=\"pdf\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "# fig.savefig(f\"{home}/images/paper/coverage_top_only.png\", format=\"png\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -22597,84 +22596,84 @@ " figsize=(14, 14),\n", " sharex=False,\n", " sharey=False,\n", - " gridspec_kw={\"height_ratios\": [1, 5], \"width_ratios\": [5, 1]},\n", + " gridspec_kw={'height_ratios': [1, 5], 'width_ratios': [5, 1]},\n", ")\n", "plt.subplots_adjust(hspace=0.025, wspace=0.025) # hspace=0.025)#right=0.975, left=0.11)\n", "# sns.histplot(ax=axs[0], data=df_cas22, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", - "plt.rc(\"legend\", loc=\"upper center\")\n", - "axs[0, 0].tick_params(axis=\"both\", labelsize=13)\n", - "axs[1, 0].tick_params(axis=\"both\", labelsize=13)\n", - "axs[1, 1].tick_params(axis=\"both\", labelsize=13)\n", - "plt.rc(\"axes\", labelsize=20)\n", - "plt.rc(\"legend\", fontsize=14)\n", + "plt.rc('legend', loc='upper center')\n", + "axs[0, 0].tick_params(axis='both', labelsize=13)\n", + "axs[1, 0].tick_params(axis='both', labelsize=13)\n", + "axs[1, 1].tick_params(axis='both', labelsize=13)\n", + "plt.rc('axes', labelsize=20)\n", + "plt.rc('legend', fontsize=14)\n", "\n", "\n", "sns.kdeplot(\n", " ax=axs[0, 0],\n", - " data=CAT[CAT[\"Dataset\"].isin([\"CASMI 16\", \"CASMI 22\"])],\n", - " x=\"coverage\",\n", + " data=CAT[CAT['Dataset'].isin(['CASMI 16', 'CASMI 22'])],\n", + " x='coverage',\n", " bw_adjust=0.2,\n", - " color=\"black\",\n", + " color='black',\n", " fill=True,\n", - " multiple=\"layer\",\n", - " hue=\"Dataset\",\n", + " multiple='layer',\n", + " hue='Dataset',\n", " common_norm=False,\n", " palette=tri_palette[1:],\n", ") # hue=\"Precursor_type\",\n", - "axs[0, 0].spines[\"top\"].set_visible(False)\n", - "axs[0, 0].spines[\"right\"].set_visible(False)\n", + "axs[0, 0].spines['top'].set_visible(False)\n", + "axs[0, 0].spines['right'].set_visible(False)\n", "axs[0, 0].set_xlim([-0.05, 1.05])\n", - "axs[0, 0].set_xlabel(\"\")\n", - "axs[0, 0].set_xticklabels(\"\")\n", + "axs[0, 0].set_xlabel('')\n", + "axs[0, 0].set_xticklabels('')\n", "# axs[0].legend(loc='upper center')\n", "axs[1, 0].set_xlim([-0.05, 1.05])\n", "axs[1, 0].set_ylim([-0.05, 1.05])\n", - "axs[1, 0].set_aspect(\"equal\")\n", + "axs[1, 0].set_aspect('equal')\n", "# Generating x values\n", "x = np.linspace(0, 1, 200)\n", "y = np.sqrt(x)\n", - "sns.lineplot(x=x, y=y, color=\"black\", linestyle=\"dotted\", ax=axs[1, 0])\n", + "sns.lineplot(x=x, y=y, color='black', linestyle='dotted', ax=axs[1, 0])\n", "\n", "\n", "# axs[0].set_title(\"Impact of coverage on cosine scores\")\n", "sns.scatterplot(\n", " ax=axs[1, 0],\n", - " data=CAT[CAT[\"Dataset\"].isin([\"CASMI 16\", \"CASMI 22\"])],\n", - " x=\"coverage\",\n", - " y=\"spectral_sqrt_cosine\",\n", - " hue=\"Dataset\",\n", - " style=\"Dataset\",\n", - " markers=[(4, 1, 0), \".\"],\n", + " data=CAT[CAT['Dataset'].isin(['CASMI 16', 'CASMI 22'])],\n", + " x='coverage',\n", + " y='spectral_sqrt_cosine',\n", + " hue='Dataset',\n", + " style='Dataset',\n", + " markers=[(4, 1, 0), '.'],\n", " s=100,\n", " palette=tri_palette[1:],\n", ") # , hue_norm=(0, 1), palette=bluepink_grad)\n", - "axs[1, 0].set_ylabel(\"Cosine similarity\")\n", - "axs[1, 0].set_xlabel(\"Peak intensity coverage\")\n", - "axs[1, 
0].legend(title=\"Dataset\", loc=\"upper left\")\n", + "axs[1, 0].set_ylabel('Cosine similarity')\n", + "axs[1, 0].set_xlabel('Peak intensity coverage')\n", + "axs[1, 0].legend(title='Dataset', loc='upper left')\n", "\n", "\n", "# Box plots\n", "\n", "sns.kdeplot(\n", " ax=axs[1, 1],\n", - " data=CAT[CAT[\"Dataset\"].isin([\"CASMI 16\", \"CASMI 22\"])],\n", - " y=\"spectral_sqrt_cosine\",\n", + " data=CAT[CAT['Dataset'].isin(['CASMI 16', 'CASMI 22'])],\n", + " y='spectral_sqrt_cosine',\n", " bw_adjust=0.2,\n", - " color=\"black\",\n", + " color='black',\n", " fill=True,\n", - " multiple=\"layer\",\n", - " hue=\"Dataset\",\n", + " multiple='layer',\n", + " hue='Dataset',\n", " common_norm=False,\n", " palette=tri_palette[1:],\n", ") # hue=\"Precursor_type\",\n", "# sns.boxplot(ax=axs[1,1], data=CAT[CAT[\"Dataset\"] != \"Test split\"], y=\"spectral_sqrt_cosine\", hue=\"Dataset\", palette=tri_palette[1:])\n", "\n", "axs[1, 1].set_ylim(axs[1, 0].get_ylim())\n", - "axs[1, 1].set_ylabel(\"\")\n", - "axs[1, 1].set_yticklabels(\"\")\n", + "axs[1, 1].set_ylabel('')\n", + "axs[1, 1].set_yticklabels('')\n", "axs[1, 1].legend().remove()\n", - "axs[1, 1].spines[\"top\"].set_visible(False)\n", - "axs[1, 1].spines[\"right\"].set_visible(False)\n", + "axs[1, 1].spines['top'].set_visible(False)\n", + "axs[1, 1].spines['right'].set_visible(False)\n", "fig.delaxes(axs[0, 1])\n", "\n", "for ax in axs.flat:\n", @@ -22683,11 +22682,11 @@ " ax.set_xlabel(ax.get_xlabel(), fontsize=14)\n", " ax.set_ylabel(ax.get_ylabel(), fontsize=14)\n", "\n", - "axs[0, 0].tick_params(axis=\"both\", labelsize=13)\n", - "axs[1, 0].tick_params(axis=\"both\", labelsize=13)\n", - "axs[1, 1].tick_params(axis=\"both\", labelsize=13)\n", - "plt.rc(\"axes\", labelsize=20)\n", - "plt.rc(\"legend\", fontsize=14)\n", + "axs[0, 0].tick_params(axis='both', labelsize=13)\n", + "axs[1, 0].tick_params(axis='both', labelsize=13)\n", + "axs[1, 1].tick_params(axis='both', labelsize=13)\n", + "plt.rc('axes', labelsize=20)\n", + "plt.rc('legend', fontsize=14)\n", "\n", "\n", "# fig.savefig(f\"{home}/images/paper/coverage.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", @@ -22736,12 +22735,12 @@ "normalized_x1 = x1 / magnitude_x1\n", "normalized_y1 = y1 / magnitude_y1\n", "\n", - "print(\"Normalized x1:\", normalized_x1)\n", - "print(\"Normalized y1:\", normalized_y1)\n", + "print('Normalized x1:', normalized_x1)\n", + "print('Normalized y1:', normalized_y1)\n", "\n", "from fiora.MS.spectral_scores import cosine\n", "\n", - "print(\"Cosine:\", cosine(normalized_x1, normalized_y1))" + "print('Cosine:', cosine(normalized_x1, normalized_y1))" ] }, { @@ -22774,39 +22773,39 @@ ], "source": [ "fig, axs = plt.subplots(\n", - " 2, 1, figsize=(6, 7), sharex=True, gridspec_kw={\"height_ratios\": [1, 5]}\n", + " 2, 1, figsize=(6, 7), sharex=True, gridspec_kw={'height_ratios': [1, 5]}\n", ")\n", "plt.subplots_adjust(hspace=0.05) # right=0.975, left=0.11)\n", "# sns.histplot(ax=axs[0], data=df_cas22, x=\"coverage\", binwidth=0.02, kde_kws={\"bw_adjust\": 0.1}, multiple=\"stack\", kde=True, color=\"black\", palette=[\"black\", \"gray\"]) #hue=\"Precursor_type\",\n", "sns.kdeplot(\n", " ax=axs[0],\n", " data=df_test,\n", - " x=\"coverage\",\n", + " x='coverage',\n", " bw_adjust=0.2,\n", - " color=\"black\",\n", + " color='black',\n", " fill=True,\n", - " multiple=\"layer\",\n", - " hue=\"Dataset\",\n", + " multiple='layer',\n", + " hue='Dataset',\n", " common_norm=False,\n", " palette=tri_palette[1:],\n", ") # hue=\"Precursor_type\",\n", - 
"axs[0].spines[\"top\"].set_visible(False)\n", - "axs[0].spines[\"right\"].set_visible(False)\n", + "axs[0].spines['top'].set_visible(False)\n", + "axs[0].spines['right'].set_visible(False)\n", "\n", - "axs[0].set_title(\"Impact of coverage on cosine scores\")\n", + "axs[0].set_title('Impact of coverage on cosine scores')\n", "sns.scatterplot(\n", " ax=axs[1],\n", " data=df_test,\n", - " x=\"coverage\",\n", - " y=\"spectral_sqrt_cosine\",\n", - " hue=\"Dataset\",\n", - " style=\"Dataset\",\n", - " markers=[\".\", \"X\", \"*\"][1:],\n", - " marker=\".\",\n", + " x='coverage',\n", + " y='spectral_sqrt_cosine',\n", + " hue='Dataset',\n", + " style='Dataset',\n", + " markers=['.', 'X', '*'][1:],\n", + " marker='.',\n", " palette=tri_palette[1:],\n", ") # , hue_norm=(0, 1), palette=bluepink_grad)\n", - "axs[1].set_ylabel(\"Cosine similarity\")\n", - "axs[1].set_xlabel(\"Peak intensity coverage\")\n", + "axs[1].set_ylabel('Cosine similarity')\n", + "axs[1].set_xlabel('Peak intensity coverage')\n", "plt.show()" ] }, @@ -22982,14 +22981,14 @@ "source": [ "ids = [17, 58, 68, 75, 102, 107, 128, 145, 163, 205]\n", "example_input = (\n", - " df_cas[df_cas[\"is_priority\"]]\n", + " df_cas[df_cas['is_priority']]\n", " .loc[ids]\n", - " .copy()[[\"SMILES\", \"avg_CE\", \"Precursor_type\", \"Instrument_type\", \"peaks\"]]\n", + " .copy()[['SMILES', 'avg_CE', 'Precursor_type', 'Instrument_type', 'peaks']]\n", ")\n", - "example_input[\"CE\"] = example_input[\"avg_CE\"].astype(int)\n", - "example_input[\"Name\"] = [f\"Example_{i}\" for i in range(example_input.shape[0])]\n", + "example_input['CE'] = example_input['avg_CE'].astype(int)\n", + "example_input['Name'] = [f'Example_{i}' for i in range(example_input.shape[0])]\n", "# example_input[[\"Name\", \"SMILES\", \"Precursor_type\", \"CE\", \"Instrument_type\"]].to_csv(\"../examples/example_input.csv\", index=False)\n", - "example_input[[\"Name\", \"SMILES\", \"Precursor_type\", \"CE\", \"Instrument_type\"]]" + "example_input[['Name', 'SMILES', 'Precursor_type', 'CE', 'Instrument_type']]" ] }, { @@ -22998,14 +22997,14 @@ "metadata": {}, "outputs": [], "source": [ - "for key in df_msnlib_test.iloc[0][\"Metabolite\"].match_stats.keys():\n", - " df_msnlib_test[key] = df_msnlib_test[\"Metabolite\"].apply(\n", + "for key in df_msnlib_test.iloc[0]['Metabolite'].match_stats.keys():\n", + " df_msnlib_test[key] = df_msnlib_test['Metabolite'].apply(\n", " lambda x: x.match_stats[key]\n", " )\n", - " df_test[key] = df_test[\"Metabolite\"].apply(lambda x: x.match_stats[key])\n", - " df_cas[key] = df_cas[\"Metabolite\"].apply(lambda x: x.match_stats[key])\n", - " df_cas22[key] = df_cas22[\"Metabolite\"].apply(lambda x: x.match_stats[key])\n", - " C[key] = C[\"Metabolite\"].apply(lambda x: x.match_stats[key])" + " df_test[key] = df_test['Metabolite'].apply(lambda x: x.match_stats[key])\n", + " df_cas[key] = df_cas['Metabolite'].apply(lambda x: x.match_stats[key])\n", + " df_cas22[key] = df_cas22['Metabolite'].apply(lambda x: x.match_stats[key])\n", + " C[key] = C['Metabolite'].apply(lambda x: x.match_stats[key])" ] }, { @@ -23030,7 +23029,7 @@ } ], "source": [ - "C.groupby(\"Precursor_type\")[\"precursor_raw_prob\"].mean()" + "C.groupby('Precursor_type')['precursor_raw_prob'].mean()" ] }, { @@ -23057,7 +23056,7 @@ } ], "source": [ - "df_msnlib_test[\"group_id\"].nunique()" + "df_msnlib_test['group_id'].nunique()" ] }, { @@ -23091,7 +23090,7 @@ } ], "source": [ - "sns.boxplot(df_msnlib_test, x=\"num_peak_matches\", y=\"spectral_sqrt_cosine_wo_prec\")\n", + 
"sns.boxplot(df_msnlib_test, x='num_peak_matches', y='spectral_sqrt_cosine_wo_prec')\n", "plt.show()" ] }, @@ -23101,7 +23100,7 @@ "metadata": {}, "outputs": [], "source": [ - "df_low = df_msnlib_test[df_msnlib_test[\"spectral_sqrt_cosine\"] < 0.3]" + "df_low = df_msnlib_test[df_msnlib_test['spectral_sqrt_cosine'] < 0.3]" ] }, { @@ -23211,20 +23210,20 @@ "reset_matplotlib()\n", "\n", "df_print = df_cas[\n", - " (df_cas[\"merged_sqrt_cosine\"] > 0.70) & (df_cas[\"merged_sqrt_bias\"] < 0.6)\n", + " (df_cas['merged_sqrt_cosine'] > 0.70) & (df_cas['merged_sqrt_bias'] < 0.6)\n", "]\n", "# df_print = df_test[(df_test[\"spectral_sqrt_cosine\"] > 0.85) & (df_test[\"spectral_sqrt_bias\"] < 0.6) & (df_test[\"lib\"] == \"MSDIAL\")]\n", "\n", "print(df_print.shape)\n", "for i, data in df_print.head(5).iterrows():\n", " fig, axs = plt.subplots(\n", - " 1, 2, figsize=(9, 3), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(9, 3), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", " )\n", - " img = data[\"Metabolite\"].draw(ax=axs[0])\n", - " print(i, data[\"ChallengeName\"])\n", + " img = data['Metabolite'].draw(ax=axs[0])\n", + " print(i, data['ChallengeName'])\n", " sv.plot_spectrum(\n", " data,\n", - " {\"peaks\": data[\"sim_peaks\"]},\n", + " {'peaks': data['sim_peaks']},\n", " highlight_matches=True,\n", " ppm_tolerance=200,\n", " ax=axs[1],\n", @@ -23239,7 +23238,7 @@ "outputs": [], "source": [ "def peaks_by_intensity(peaks):\n", - " z = zip(peaks[\"mz\"], peaks[\"intensity\"], peaks[\"annotation\"])\n", + " z = zip(peaks['mz'], peaks['intensity'], peaks['annotation'])\n", " return sorted(z, key=lambda x: x[1], reverse=True)" ] }, @@ -23549,18 +23548,18 @@ "\n", "for i in smallbutrelatable:\n", " data = df_test.loc[i]\n", - " print(i, data[\"CE\"], data[\"Precursor_type\"])\n", - " f = peaks_by_intensity(data[\"sim_peaks\"])\n", + " print(i, data['CE'], data['Precursor_type'])\n", + " f = peaks_by_intensity(data['sim_peaks'])\n", " fig, ax = plt.subplots(1, 1, figsize=(1.5, 1.5))\n", - " Metabolite(f[0][2].split(\"//\")[0]).draw()\n", + " Metabolite(f[0][2].split('//')[0]).draw()\n", " plt.show()\n", " fig, axs = plt.subplots(\n", - " 1, 2, figsize=(9, 3), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, figsize=(9, 3), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", " )\n", - " img = data[\"Metabolite\"].draw(ax=axs[0])\n", + " img = data['Metabolite'].draw(ax=axs[0])\n", " sv.plot_spectrum(\n", " data,\n", - " {\"peaks\": data[\"sim_peaks\"]},\n", + " {'peaks': data['sim_peaks']},\n", " highlight_matches=True,\n", " ppm_tolerance=200,\n", " ax=axs[1],\n", @@ -23592,9 +23591,9 @@ } ], "source": [ - "df_train[df_train[\"Name\"] == \"Indole-3-acetyl-L-alanine\"][\"Metabolite\"].iloc[\n", + "df_train[df_train['Name'] == 'Indole-3-acetyl-L-alanine']['Metabolite'].iloc[\n", " 0\n", - "].tanimoto_similarity(data[\"Metabolite\"])" + "].tanimoto_similarity(data['Metabolite'])" ] }, { @@ -23630,22 +23629,22 @@ "source": [ "### GOLD EXAMPLE\n", "data = df_test.loc[80082]\n", - "print(data[\"Name\"], data[\"group_id\"])\n", + "print(data['Name'], data['group_id'])\n", "reset_matplotlib()\n", "\n", - "f = peaks_by_intensity(data[\"sim_peaks\"])\n", + "f = peaks_by_intensity(data['sim_peaks'])\n", "print(f[0])\n", - "Metabolite(f[0][2].split(\"//\")[0]).draw()\n", + "Metabolite(f[0][2].split('//')[0]).draw()\n", "plt.show()\n", "fig, axs = plt.subplots(\n", - " 1, 2, figsize=(9, 3), gridspec_kw={\"width_ratios\": [1, 3]}, sharey=False\n", + " 1, 2, 
figsize=(9, 3), gridspec_kw={'width_ratios': [1, 3]}, sharey=False\n", ")\n", - "img = data[\"Metabolite\"].draw(ax=axs[0])\n", + "img = data['Metabolite'].draw(ax=axs[0])\n", "\n", "\n", "sv.plot_spectrum(\n", " data,\n", - " {\"peaks\": data[\"sim_peaks\"]},\n", + " {'peaks': data['sim_peaks']},\n", " highlight_matches=True,\n", " ppm_tolerance=200,\n", " ax=axs[1],\n", @@ -23660,7 +23659,7 @@ "\n", "plt.show()\n", "fig, ax = plt.subplots(1, 1, figsize=(9, 3))\n", - "sv.plot_spectrum({\"peaks\": data[\"sim_peaks\"]}, ax=ax)\n", + "sv.plot_spectrum({'peaks': data['sim_peaks']}, ax=ax)\n", "# fig.savefig(f\"{home}/images/paper/ex_prediction.svg\", format=\"svg\", dpi=600, bbox_inches='tight', pad_inches=0.1)\n", "\n", "plt.show()" @@ -23803,7 +23802,7 @@ } ], "source": [ - "df_test[df_test[\"Name\"] == \"Indole-3-acetyl-L-alanine\"]" + "df_test[df_test['Name'] == 'Indole-3-acetyl-L-alanine']" ] }, { @@ -23826,8 +23825,7 @@ "import rdkit.Chem as Chem\n", "from rdkit.Chem import Descriptors\n", "\n", - "\n", - "h_plus = Chem.MolFromSmiles(\"[H+]\")\n", + "h_plus = Chem.MolFromSmiles('[H+]')\n", "Chem.Descriptors.ExactMolWt(h_plus)" ] }, @@ -23848,7 +23846,7 @@ } ], "source": [ - "s = \"C=c1c[nH+]c2ccccc12\"\n", + "s = 'C=c1c[nH+]c2ccccc12'\n", "# s=\"C=C1C=[NH+]C2=C1C=CC=C2\"\n", "m = Metabolite(s)\n", "m.draw()\n", @@ -23873,15 +23871,15 @@ } ], "source": [ - "from rdkit.Chem.Draw import rdMolDraw2D\n", "import cairosvg\n", + "from rdkit.Chem.Draw import rdMolDraw2D\n", "\n", "save_fragments = True\n", "\n", "for i in range(len(f)):\n", " print(f[i])\n", " fig, ax = plt.subplots(1, 1, figsize=(2, 2))\n", - " m = Metabolite(f[i][2].split(\"//\")[0])\n", + " m = Metabolite(f[i][2].split('//')[0])\n", " m.draw()\n", " plt.show()\n", " if save_fragments:\n", @@ -23890,11 +23888,11 @@ " drawer.FinishDrawing()\n", " cairosvg.svg2pdf(\n", " bytestring=drawer.GetDrawingText().encode(),\n", - " write_to=f\"{home}/images/paper/molecule_f{i}.pdf\",\n", + " write_to=f'{home}/images/paper/molecule_f{i}.pdf',\n", " )\n", " cairosvg.svg2svg(\n", " bytestring=drawer.GetDrawingText().encode(),\n", - " write_to=f\"{home}/images/paper/molecule_f{i}.svg\",\n", + " write_to=f'{home}/images/paper/molecule_f{i}.svg',\n", " )" ] }, @@ -23915,23 +23913,23 @@ } ], "source": [ - "from rdkit.Chem.Draw import rdMolDraw2D\n", "import cairosvg\n", + "from rdkit.Chem.Draw import rdMolDraw2D\n", "\n", "# fig, ax = plt.subplots(1,1, figsize=(10, 10))\n", "# data[\"Metabolite\"].draw(ax= ax)\n", "\n", "if False:\n", " drawer = rdMolDraw2D.MolDraw2DSVG(500, 500)\n", - " drawer.DrawMolecule(data[\"Metabolite\"].MOL)\n", + " drawer.DrawMolecule(data['Metabolite'].MOL)\n", " drawer.FinishDrawing()\n", " cairosvg.svg2pdf(\n", " bytestring=drawer.GetDrawingText().encode(),\n", - " write_to=f\"{home}/images/paper/molecule.pdf\",\n", + " write_to=f'{home}/images/paper/molecule.pdf',\n", " )\n", " cairosvg.svg2svg(\n", " bytestring=drawer.GetDrawingText().encode(),\n", - " write_to=f\"{home}/images/paper/molecule.svg\",\n", + " write_to=f'{home}/images/paper/molecule.svg',\n", " )" ] }, @@ -23955,53 +23953,53 @@ "import matplotlib.patches as mpatches\n", "\n", "\n", - "def double_mirrorplot(i, model_title=\"Fiora\"):\n", + "def double_mirrorplot(i, model_title='Fiora'):\n", " fig, axs = plt.subplots(\n", - " 1, 3, figsize=(16.8, 4.2), gridspec_kw={\"width_ratios\": [1, 3, 3]}, sharey=False\n", + " 1, 3, figsize=(16.8, 4.2), gridspec_kw={'width_ratios': [1, 3, 3]}, sharey=False\n", " )\n", "\n", " plt.subplots_adjust(right=0.975, 
left=0.025)\n", "\n", - " img = df_cas.loc[i][\"Metabolite\"].draw(ax=axs[0])\n", + " img = df_cas.loc[i]['Metabolite'].draw(ax=axs[0])\n", "\n", " axs[0].grid(False)\n", " axs[0].tick_params(\n", - " axis=\"both\", bottom=False, labelbottom=False, left=False, labelleft=False\n", + " axis='both', bottom=False, labelbottom=False, left=False, labelleft=False\n", " )\n", " axs[0].set_title(\n", - " df_cas.loc[i][\"NAME\"] + \"\\n(\" + df_cas.loc[i][\"ChallengeName\"] + \")\"\n", + " df_cas.loc[i]['NAME'] + '\\n(' + df_cas.loc[i]['ChallengeName'] + ')'\n", " )\n", " axs[0].imshow(img)\n", - " axs[0].axis(\"off\")\n", + " axs[0].axis('off')\n", "\n", " sv.plot_spectrum(\n", - " {\"peaks\": df_cas.loc[i][\"peaks\"]},\n", - " {\"peaks\": df_cas.loc[i][\"merged_peaks\"]},\n", + " {'peaks': df_cas.loc[i]['peaks']},\n", + " {'peaks': df_cas.loc[i]['merged_peaks']},\n", " ax=axs[1],\n", " )\n", " axs[1].title.set_text(model_title)\n", " patch1 = mpatches.Patch(\n", - " color=\"limegreen\"\n", - " if df_cas.loc[i][\"ice_sqrt_cosine\"] < df_cas.loc[i][\"merged_sqrt_cosine\"]\n", - " else \"orangered\",\n", - " label=f\"cosine {df_cas.loc[i]['merged_sqrt_cosine']:.02f}\",\n", + " color='limegreen'\n", + " if df_cas.loc[i]['ice_sqrt_cosine'] < df_cas.loc[i]['merged_sqrt_cosine']\n", + " else 'orangered',\n", + " label=f'cosine {df_cas.loc[i][\"merged_sqrt_cosine\"]:.02f}',\n", " )\n", " axs[1].legend(handles=[patch1])\n", "\n", " sv.plot_spectrum(\n", - " {\"peaks\": df_cas.loc[i][\"peaks\"]},\n", - " {\"peaks\": df_cas.loc[i][\"ice_peaks\"]}\n", - " if df_cas.loc[i][\"ice_peaks\"]\n", - " else {\"peaks\": {\"mz\": [0], \"intensity\": [0]}},\n", + " {'peaks': df_cas.loc[i]['peaks']},\n", + " {'peaks': df_cas.loc[i]['ice_peaks']}\n", + " if df_cas.loc[i]['ice_peaks']\n", + " else {'peaks': {'mz': [0], 'intensity': [0]}},\n", " ax=axs[2],\n", " )\n", - " axs[2].title.set_text(f\"ICEBERG\")\n", + " axs[2].title.set_text('ICEBERG')\n", "\n", " patch2 = mpatches.Patch(\n", - " color=\"limegreen\"\n", - " if df_cas.loc[i][\"ice_sqrt_cosine\"] > df_cas.loc[i][\"merged_sqrt_cosine\"]\n", - " else \"orangered\",\n", - " label=f\"cosine {df_cas.loc[i]['ice_sqrt_cosine']:.02f}\",\n", + " color='limegreen'\n", + " if df_cas.loc[i]['ice_sqrt_cosine'] > df_cas.loc[i]['merged_sqrt_cosine']\n", + " else 'orangered',\n", + " label=f'cosine {df_cas.loc[i][\"ice_sqrt_cosine\"]:.02f}',\n", " )\n", " axs[2].legend(handles=[patch2])\n", "\n", @@ -24047,7 +24045,7 @@ } ], "source": [ - "sns.histplot(data=df_test, x=\"CE\", hue=\"lib\")\n", + "sns.histplot(data=df_test, x='CE', hue='lib')\n", "plt.show()" ] }, @@ -24073,18 +24071,18 @@ "\n", "sns.histplot(\n", " ax=axs[0],\n", - " data=df_test[df_test[\"Precursor_type\"] == \"[M+H]+\"],\n", - " x=\"spectral_sqrt_cosine\",\n", - " hue=\"lib\",\n", + " data=df_test[df_test['Precursor_type'] == '[M+H]+'],\n", + " x='spectral_sqrt_cosine',\n", + " hue='lib',\n", ")\n", - "axs[0].set_title(\"[M+H]+ Test split\")\n", + "axs[0].set_title('[M+H]+ Test split')\n", "sns.histplot(\n", " ax=axs[1],\n", - " data=df_test[df_test[\"Precursor_type\"] == \"[M-H]-\"],\n", - " x=\"spectral_sqrt_cosine\",\n", - " hue=\"lib\",\n", + " data=df_test[df_test['Precursor_type'] == '[M-H]-'],\n", + " x='spectral_sqrt_cosine',\n", + " hue='lib',\n", ")\n", - "axs[1].set_title(\"[M-H]- Test split\")\n", + "axs[1].set_title('[M-H]- Test split')\n", "plt.show()" ] }, @@ -24105,22 +24103,22 @@ } ], "source": [ - "print(df_cas.groupby(\"Precursor_type\")[\"coverage\"].median())\n", + 
"print(df_cas.groupby('Precursor_type')['coverage'].median())\n", "fig, axs = plt.subplots(1, 2, figsize=(12, 6))\n", "\n", - "sns.kdeplot(ax=axs[0], data=df_cas, x=\"coverage\", hue=\"Precursor_type\", bw_adjust=0.5)\n", + "sns.kdeplot(ax=axs[0], data=df_cas, x='coverage', hue='Precursor_type', bw_adjust=0.5)\n", "sns.scatterplot(\n", - " ax=axs[1], data=df_cas, x=\"coverage\", y=\"spectral_sqrt_cosine\", hue=\"Precursor_type\"\n", + " ax=axs[1], data=df_cas, x='coverage', y='spectral_sqrt_cosine', hue='Precursor_type'\n", ")\n", "axs[0].axvline(\n", - " x=df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\"coverage\"].median(),\n", - " color=\"black\",\n", - " linestyle=\"--\",\n", + " x=df_cas[df_cas['Precursor_type'] == '[M+H]+']['coverage'].median(),\n", + " color='black',\n", + " linestyle='--',\n", ")\n", "axs[1].axvline(\n", - " x=df_cas[df_cas[\"Precursor_type\"] == \"[M+H]+\"][\"coverage\"].median(),\n", - " color=\"black\",\n", - " linestyle=\"--\",\n", + " x=df_cas[df_cas['Precursor_type'] == '[M+H]+']['coverage'].median(),\n", + " color='black',\n", + " linestyle='--',\n", ")\n", "plt.show()" ] @@ -24161,7 +24159,7 @@ ], "source": [ "sns.kdeplot(\n", - " df_cas22, x=\"CE\", y=\"spectral_sqrt_cosine\", hue=\"Precursor_type\"\n", + " df_cas22, x='CE', y='spectral_sqrt_cosine', hue='Precursor_type'\n", ") # , hue=\"lib\")\n", "plt.show()" ] @@ -24201,8 +24199,8 @@ } ], "source": [ - "df_test[\"pp_score_dif\"] = (\n", - " df_test[\"spectral_sqrt_cosine\"] - df_test[\"spectral_sqrt_cosine_wo_prec\"]\n", + "df_test['pp_score_dif'] = (\n", + " df_test['spectral_sqrt_cosine'] - df_test['spectral_sqrt_cosine_wo_prec']\n", ")" ] }, @@ -24223,7 +24221,7 @@ } ], "source": [ - "sns.boxplot(df_test, x=\"pp_score_dif\") # , hue=\"lib\")\n", + "sns.boxplot(df_test, x='pp_score_dif') # , hue=\"lib\")\n", "plt.xlim(-0.5, 0.5)\n", "plt.show()" ] @@ -24245,7 +24243,7 @@ } ], "source": [ - "sns.scatterplot(df_test, x=\"pp_score_dif\", y=\"pp\", hue=\"CE\") # , hue=\"lib\")\n", + "sns.scatterplot(df_test, x='pp_score_dif', y='pp', hue='CE') # , hue=\"lib\")\n", "plt.xlim(-0.2, 0.6)\n", "plt.show()" ] @@ -24267,7 +24265,7 @@ } ], "source": [ - "sns.histplot(df_cas22, x=\"CE\", hue=\"Precursor_type\", binwidth=3) # , hue=\"lib\")\n", + "sns.histplot(df_cas22, x='CE', hue='Precursor_type', binwidth=3) # , hue=\"lib\")\n", "plt.xlim(0, 100)\n", "plt.show()" ] @@ -24290,7 +24288,7 @@ ], "source": [ "sns.histplot(\n", - " df_cas, x=\"CE\", hue=\"Precursor_type\", hue_order=[\"[M+H]+\", \"[M-H]-\"], binwidth=3\n", + " df_cas, x='CE', hue='Precursor_type', hue_order=['[M+H]+', '[M-H]-'], binwidth=3\n", ") # , hue=\"lib\")\n", "plt.xlim(0, 100)\n", "plt.show()" @@ -24316,10 +24314,10 @@ "fig, axs = plt.subplots(1, 2, figsize=(8, 4))\n", "\n", "sns.boxplot(\n", - " ax=axs[0], data=df_test, y=\"Precursor_ppm_error\", hue=\"lib\", showfliers=False\n", + " ax=axs[0], data=df_test, y='Precursor_ppm_error', hue='lib', showfliers=False\n", ")\n", "sns.boxplot(\n", - " ax=axs[1], data=df_test, y=\"Precursor_abs_error\", hue=\"lib\", showfliers=False\n", + " ax=axs[1], data=df_test, y='Precursor_abs_error', hue='lib', showfliers=False\n", ")\n", "plt.show()" ] @@ -24366,11 +24364,11 @@ "source": [ "sns.kdeplot(\n", " data=df_test[\n", - " (df_test[\"Precursor_ppm_error\"] < 1) & (df_test[\"Precursor_ppm_error\"] < 2)\n", + " (df_test['Precursor_ppm_error'] < 1) & (df_test['Precursor_ppm_error'] < 2)\n", " ],\n", - " y=\"spectral_sqrt_cosine\",\n", - " x=\"Precursor_ppm_error\",\n", - " hue=\"lib\",\n", + " 
y='spectral_sqrt_cosine',\n", + " x='Precursor_ppm_error',\n", + " hue='lib',\n", ")\n", "plt.xlim(0, 2)\n", "plt.show()" @@ -24400,7 +24398,7 @@ } ], "source": [ - "df_cast.groupby(\"Precursor_type\")[\"spectral_sqrt_cosine\"].median()" + "df_cast.groupby('Precursor_type')['spectral_sqrt_cosine'].median()" ] }, { @@ -24420,26 +24418,26 @@ } ], "source": [ - "print(df_cast.groupby(\"Precursor_type\")[\"coverage\"].median())\n", + "print(df_cast.groupby('Precursor_type')['coverage'].median())\n", "fig, axs = plt.subplots(1, 2, figsize=(12, 6))\n", "\n", - "sns.kdeplot(ax=axs[0], data=df_cast, x=\"coverage\", hue=\"Precursor_type\", bw_adjust=0.5)\n", + "sns.kdeplot(ax=axs[0], data=df_cast, x='coverage', hue='Precursor_type', bw_adjust=0.5)\n", "sns.scatterplot(\n", " ax=axs[1],\n", " data=df_cast,\n", - " x=\"coverage\",\n", - " y=\"spectral_sqrt_cosine\",\n", - " hue=\"Precursor_type\",\n", + " x='coverage',\n", + " y='spectral_sqrt_cosine',\n", + " hue='Precursor_type',\n", ")\n", "axs[0].axvline(\n", - " x=df_cast[df_cast[\"Precursor_type\"] == \"[M+H]+\"][\"coverage\"].median(),\n", - " color=\"black\",\n", - " linestyle=\"--\",\n", + " x=df_cast[df_cast['Precursor_type'] == '[M+H]+']['coverage'].median(),\n", + " color='black',\n", + " linestyle='--',\n", ")\n", "axs[1].axvline(\n", - " x=df_cast[df_cast[\"Precursor_type\"] == \"[M+H]+\"][\"coverage\"].median(),\n", - " color=\"black\",\n", - " linestyle=\"--\",\n", + " x=df_cast[df_cast['Precursor_type'] == '[M+H]+']['coverage'].median(),\n", + " color='black',\n", + " linestyle='--',\n", ")\n", "plt.show()" ] @@ -24461,11 +24459,11 @@ } ], "source": [ - "for i, m in enumerate(df_cast[\"Metabolite\"]):\n", - " for m2 in df_train[\"Metabolite\"]:\n", + "for i, m in enumerate(df_cast['Metabolite']):\n", + " for m2 in df_train['Metabolite']:\n", " if m == m2:\n", " print(i)\n", - " print(\"Violation\")\n", + " print('Violation')\n", " break" ] }, @@ -24482,14 +24480,14 @@ "metadata": {}, "outputs": [], "source": [ - "df_cas[\"lib\"] = \"CASMI 16\"\n", - "df_cas22[\"lib\"] = \"CASMI 22\"\n", - "df_cas[\"Name\"] = df_cas[\"ChallengeName\"]\n", - "df_cas22[\"Name\"] = df_cas22[\"ChallengeName\"]\n", - "df_msnlib_test[\"Name\"] = df_msnlib_test[\"NAME\"]\n", - "df_msnlib_test[\"Instrument_type\"] = \"HCD\"\n", + "df_cas['lib'] = 'CASMI 16'\n", + "df_cas22['lib'] = 'CASMI 22'\n", + "df_cas['Name'] = df_cas['ChallengeName']\n", + "df_cas22['Name'] = df_cas22['ChallengeName']\n", + "df_msnlib_test['Name'] = df_msnlib_test['NAME']\n", + "df_msnlib_test['Instrument_type'] = 'HCD'\n", "TEST = pd.concat(\n", - " [df_test[df_test[\"lib\"] != \"NIST\"], df_msnlib_test, df_cas, df_cas22],\n", + " [df_test[df_test['lib'] != 'NIST'], df_msnlib_test, df_cas, df_cas22],\n", " ignore_index=True,\n", ")" ] @@ -24512,6 +24510,7 @@ ], "source": [ "import importlib\n", + "\n", "import fiora.IO.mgfWriter as mgfWriter\n", "\n", "importlib.reload(mgfWriter)" @@ -24526,24 +24525,24 @@ "import fiora.IO.mgfWriter as mgfWriter\n", "\n", "headers = [\n", - " \"TITLE\",\n", - " \"SMILES\",\n", - " \"PRECURSORTYPE\",\n", - " \"COLLISIONENERGY\",\n", - " \"INSTRUMENTTYPE\",\n", - " \"SOURCE\",\n", + " 'TITLE',\n", + " 'SMILES',\n", + " 'PRECURSORTYPE',\n", + " 'COLLISIONENERGY',\n", + " 'INSTRUMENTTYPE',\n", + " 'SOURCE',\n", "]\n", "mgfWriter.write_mgf(\n", " TEST,\n", - " path=f\"{home}/data/archive/fiora_source_data/testing/ground_truth_spectra.mgf\",\n", + " path=f'{home}/data/archive/fiora_source_data/testing/ground_truth_spectra.mgf',\n", " write_header=True,\n", " 
headers=headers,\n", " header_map={\n", - " \"TITLE\": \"Name\",\n", - " \"PRECURSORTYPE\": \"Precursor_type\",\n", - " \"INSTRUMENTTYPE\": \"Instrument_type\",\n", - " \"COLLISIONENERGY\": \"CE\",\n", - " \"SOURCE\": \"lib\",\n", + " 'TITLE': 'Name',\n", + " 'PRECURSORTYPE': 'Precursor_type',\n", + " 'INSTRUMENTTYPE': 'Instrument_type',\n", + " 'COLLISIONENERGY': 'CE',\n", + " 'SOURCE': 'lib',\n", " },\n", " annotation=False,\n", ")" @@ -24556,27 +24555,27 @@ "outputs": [], "source": [ "headers = [\n", - " \"TITLE\",\n", - " \"SMILES\",\n", - " \"PRECURSORTYPE\",\n", - " \"COLLISIONENERGY\",\n", - " \"INSTRUMENTTYPE\",\n", - " \"SOURCE\",\n", - " \"COMMENT\",\n", + " 'TITLE',\n", + " 'SMILES',\n", + " 'PRECURSORTYPE',\n", + " 'COLLISIONENERGY',\n", + " 'INSTRUMENTTYPE',\n", + " 'SOURCE',\n", + " 'COMMENT',\n", "]\n", - "TEST[\"COMMENT\"] = f'\"In silico generated spectrum by Fiora (pre-release version)\"'\n", + "TEST['COMMENT'] = '\"In silico generated spectrum by Fiora (pre-release version)\"'\n", "mgfWriter.write_mgf(\n", " TEST,\n", - " peak_tag=\"sim_peaks\",\n", - " path=f\"{home}/data/archive/fiora_source_data/testing/fiora_predicted_spectra.mgf\",\n", + " peak_tag='sim_peaks',\n", + " path=f'{home}/data/archive/fiora_source_data/testing/fiora_predicted_spectra.mgf',\n", " write_header=True,\n", " headers=headers,\n", " header_map={\n", - " \"TITLE\": \"Name\",\n", - " \"PRECURSORTYPE\": \"Precursor_type\",\n", - " \"INSTRUMENTTYPE\": \"Instrument_type\",\n", - " \"COLLISIONENERGY\": \"CE\",\n", - " \"SOURCE\": \"lib\",\n", + " 'TITLE': 'Name',\n", + " 'PRECURSORTYPE': 'Precursor_type',\n", + " 'INSTRUMENTTYPE': 'Instrument_type',\n", + " 'COLLISIONENERGY': 'CE',\n", + " 'SOURCE': 'lib',\n", " },\n", " annotation=False,\n", ")" @@ -24588,19 +24587,19 @@ "metadata": {}, "outputs": [], "source": [ - "TEST[\"COMMENT\"] = f'\"In silico generated spectrum by ICEBERG\"'\n", + "TEST['COMMENT'] = '\"In silico generated spectrum by ICEBERG\"'\n", "mgfWriter.write_mgf(\n", - " TEST[TEST[\"Precursor_type\"] == \"[M+H]+\"],\n", - " peak_tag=\"ice_peaks\",\n", - " path=f\"{home}/data/archive/fiora_source_data/testing/iceberg_predicted_spectra.mgf\",\n", + " TEST[TEST['Precursor_type'] == '[M+H]+'],\n", + " peak_tag='ice_peaks',\n", + " path=f'{home}/data/archive/fiora_source_data/testing/iceberg_predicted_spectra.mgf',\n", " write_header=True,\n", " headers=headers,\n", " header_map={\n", - " \"TITLE\": \"Name\",\n", - " \"PRECURSORTYPE\": \"Precursor_type\",\n", - " \"INSTRUMENTTYPE\": \"Instrument_type\",\n", - " \"COLLISIONENERGY\": \"CE\",\n", - " \"SOURCE\": \"lib\",\n", + " 'TITLE': 'Name',\n", + " 'PRECURSORTYPE': 'Precursor_type',\n", + " 'INSTRUMENTTYPE': 'Instrument_type',\n", + " 'COLLISIONENERGY': 'CE',\n", + " 'SOURCE': 'lib',\n", " },\n", " annotation=False,\n", ")" @@ -24612,19 +24611,19 @@ "metadata": {}, "outputs": [], "source": [ - "TEST[\"COMMENT\"] = f'\"In silico generated spectrum by CFM-ID (v4.4.7)\"'\n", + "TEST['COMMENT'] = '\"In silico generated spectrum by CFM-ID (v4.4.7)\"'\n", "mgfWriter.write_mgf(\n", - " TEST[~TEST[\"cfm_peaks\"].isna()],\n", - " peak_tag=\"cfm_peaks\",\n", - " path=f\"{home}/data/archive/fiora_source_data/testing/cfm-id_predicted_spectra.mgf\",\n", + " TEST[~TEST['cfm_peaks'].isna()],\n", + " peak_tag='cfm_peaks',\n", + " path=f'{home}/data/archive/fiora_source_data/testing/cfm-id_predicted_spectra.mgf',\n", " write_header=True,\n", " headers=headers,\n", " header_map={\n", - " \"TITLE\": \"Name\",\n", - " \"PRECURSORTYPE\": 
\"Precursor_type\",\n", - " \"INSTRUMENTTYPE\": \"Instrument_type\",\n", - " \"COLLISIONENERGY\": \"CE\",\n", - " \"SOURCE\": \"lib\",\n", + " 'TITLE': 'Name',\n", + " 'PRECURSORTYPE': 'Precursor_type',\n", + " 'INSTRUMENTTYPE': 'Instrument_type',\n", + " 'COLLISIONENERGY': 'CE',\n", + " 'SOURCE': 'lib',\n", " },\n", " annotation=False,\n", ")" @@ -24645,24 +24644,24 @@ "metadata": {}, "outputs": [], "source": [ - "TEST[\"peaks\"] = TEST[\"peaks\"].apply(\n", + "TEST['peaks'] = TEST['peaks'].apply(\n", " lambda x: (\n", - " {k: v for k, v in x.items() if k != \"mz_int\"} if isinstance(x, dict) else x\n", + " {k: v for k, v in x.items() if k != 'mz_int'} if isinstance(x, dict) else x\n", " )\n", ")\n", - "TEST[\"sim_peaks\"] = TEST[\"sim_peaks\"].apply(\n", + "TEST['sim_peaks'] = TEST['sim_peaks'].apply(\n", " lambda x: (\n", - " {k: v for k, v in x.items() if k != \"mz_int\"} if isinstance(x, dict) else x\n", + " {k: v for k, v in x.items() if k != 'mz_int'} if isinstance(x, dict) else x\n", " )\n", ")\n", - "TEST[\"ice_peaks\"] = TEST[\"ice_peaks\"].apply(\n", + "TEST['ice_peaks'] = TEST['ice_peaks'].apply(\n", " lambda x: (\n", - " {k: v for k, v in x.items() if k != \"mz_int\"} if isinstance(x, dict) else x\n", + " {k: v for k, v in x.items() if k != 'mz_int'} if isinstance(x, dict) else x\n", " )\n", ")\n", - "TEST[\"cfm_peaks\"] = TEST[\"cfm_peaks\"].apply(\n", + "TEST['cfm_peaks'] = TEST['cfm_peaks'].apply(\n", " lambda x: (\n", - " {k: v for k, v in x.items() if k != \"mz_int\"} if isinstance(x, dict) else x\n", + " {k: v for k, v in x.items() if k != 'mz_int'} if isinstance(x, dict) else x\n", " )\n", ")" ] @@ -24917,50 +24916,50 @@ "source": [ "TEST[\n", " [\n", - " \"Name\",\n", - " \"SMILES\",\n", - " \"InChIKey\",\n", - " \"group_id\",\n", - " \"Precursor_type\",\n", - " \"Instrument_type\",\n", - " \"CE\",\n", - " \"Dataset\",\n", - " \"lib\",\n", - " \"summary\",\n", - " \"peaks\",\n", - " \"sim_peaks\",\n", - " \"ice_peaks\",\n", - " \"cfm_peaks\",\n", - " \"spectral_cosine\",\n", - " \"spectral_sqrt_cosine\",\n", - " \"spectral_sqrt_cosine_wo_prec\",\n", - " \"spectral_refl_cosine\",\n", - " \"spectral_bias\",\n", - " \"spectral_sqrt_bias\",\n", - " \"spectral_sqrt_bias_wo_prec\",\n", - " \"spectral_refl_bias\",\n", - " \"steins_cosine\",\n", - " \"steins_bias\",\n", - " \"cfm_sqrt_cosine\",\n", - " \"cfm_refl_cosine\",\n", - " \"cfm_sqrt_cosine_wo_prec\",\n", - " \"cfm_steins\",\n", - " \"ice_name\",\n", - " \"ice_peaks\",\n", - " \"ice_cosine\",\n", - " \"ice_sqrt_cosine\",\n", - " \"ice_sqrt_cosine_wo_prec\",\n", - " \"ice_refl_cosine\",\n", - " \"ice_steins\",\n", - " \"RETENTIONTIME\",\n", - " \"RT_pred\",\n", - " \"CCS\",\n", - " \"CCS_pred\",\n", - " \"coverage\",\n", - " \"tanimoto\",\n", - " \"tanimoto3\",\n", + " 'Name',\n", + " 'SMILES',\n", + " 'InChIKey',\n", + " 'group_id',\n", + " 'Precursor_type',\n", + " 'Instrument_type',\n", + " 'CE',\n", + " 'Dataset',\n", + " 'lib',\n", + " 'summary',\n", + " 'peaks',\n", + " 'sim_peaks',\n", + " 'ice_peaks',\n", + " 'cfm_peaks',\n", + " 'spectral_cosine',\n", + " 'spectral_sqrt_cosine',\n", + " 'spectral_sqrt_cosine_wo_prec',\n", + " 'spectral_refl_cosine',\n", + " 'spectral_bias',\n", + " 'spectral_sqrt_bias',\n", + " 'spectral_sqrt_bias_wo_prec',\n", + " 'spectral_refl_bias',\n", + " 'steins_cosine',\n", + " 'steins_bias',\n", + " 'cfm_sqrt_cosine',\n", + " 'cfm_refl_cosine',\n", + " 'cfm_sqrt_cosine_wo_prec',\n", + " 'cfm_steins',\n", + " 'ice_name',\n", + " 'ice_peaks',\n", + " 'ice_cosine',\n", + " 
'ice_sqrt_cosine',\n", + " 'ice_sqrt_cosine_wo_prec',\n", + " 'ice_refl_cosine',\n", + " 'ice_steins',\n", + " 'RETENTIONTIME',\n", + " 'RT_pred',\n", + " 'CCS',\n", + " 'CCS_pred',\n", + " 'coverage',\n", + " 'tanimoto',\n", + " 'tanimoto3',\n", " ]\n", - "].to_csv(f\"{home}/data/archive/fiora_source_data/testing/dataframe.csv\")" + "].to_csv(f'{home}/data/archive/fiora_source_data/testing/dataframe.csv')" ] }, { @@ -24988,10 +24987,10 @@ } ], "source": [ - "df_msnlib_train[\"Name\"] = df_msnlib_train[\"NAME\"]\n", - "df_msnlib_train[\"Instrument_type\"] = \"HCD\"\n", + "df_msnlib_train['Name'] = df_msnlib_train['NAME']\n", + "df_msnlib_train['Instrument_type'] = 'HCD'\n", "TRAIN = pd.concat(\n", - " [df_train[df_train[\"lib\"] != \"NIST\"], df_msnlib_train], ignore_index=True\n", + " [df_train[df_train['lib'] != 'NIST'], df_msnlib_train], ignore_index=True\n", ")" ] }, @@ -25001,7 +25000,7 @@ "metadata": {}, "outputs": [], "source": [ - "TRAIN[\"datasplit\"] = TRAIN[\"dataset\"]" + "TRAIN['datasplit'] = TRAIN['dataset']" ] }, { @@ -25024,7 +25023,7 @@ } ], "source": [ - "TRAIN[\"lib\"].value_counts(dropna=False)" + "TRAIN['lib'].value_counts(dropna=False)" ] }, { @@ -25034,39 +25033,39 @@ "outputs": [], "source": [ "headers = [\n", - " \"TITLE\",\n", - " \"SMILES\",\n", - " \"PRECURSORTYPE\",\n", - " \"COLLISIONENERGY\",\n", - " \"INSTRUMENTTYPE\",\n", - " \"SOURCE\",\n", + " 'TITLE',\n", + " 'SMILES',\n", + " 'PRECURSORTYPE',\n", + " 'COLLISIONENERGY',\n", + " 'INSTRUMENTTYPE',\n", + " 'SOURCE',\n", "]\n", "\n", "mgfWriter.write_mgf(\n", - " TRAIN[TRAIN[\"datasplit\"] == \"training\"],\n", - " path=f\"{home}/data/archive/fiora_source_data/training/training_spectra.mgf\",\n", + " TRAIN[TRAIN['datasplit'] == 'training'],\n", + " path=f'{home}/data/archive/fiora_source_data/training/training_spectra.mgf',\n", " write_header=True,\n", " headers=headers,\n", " header_map={\n", - " \"TITLE\": \"Name\",\n", - " \"PRECURSORTYPE\": \"Precursor_type\",\n", - " \"INSTRUMENTTYPE\": \"Instrument_type\",\n", - " \"COLLISIONENERGY\": \"CE\",\n", - " \"SOURCE\": \"lib\",\n", + " 'TITLE': 'Name',\n", + " 'PRECURSORTYPE': 'Precursor_type',\n", + " 'INSTRUMENTTYPE': 'Instrument_type',\n", + " 'COLLISIONENERGY': 'CE',\n", + " 'SOURCE': 'lib',\n", " },\n", " annotation=False,\n", ")\n", "mgfWriter.write_mgf(\n", - " TRAIN[TRAIN[\"datasplit\"] == \"validation\"],\n", - " path=f\"{home}/data/archive/fiora_source_data/training/validation_spectra.mgf\",\n", + " TRAIN[TRAIN['datasplit'] == 'validation'],\n", + " path=f'{home}/data/archive/fiora_source_data/training/validation_spectra.mgf',\n", " write_header=True,\n", " headers=headers,\n", " header_map={\n", - " \"TITLE\": \"Name\",\n", - " \"PRECURSORTYPE\": \"Precursor_type\",\n", - " \"INSTRUMENTTYPE\": \"Instrument_type\",\n", - " \"COLLISIONENERGY\": \"CE\",\n", - " \"SOURCE\": \"lib\",\n", + " 'TITLE': 'Name',\n", + " 'PRECURSORTYPE': 'Precursor_type',\n", + " 'INSTRUMENTTYPE': 'Instrument_type',\n", + " 'COLLISIONENERGY': 'CE',\n", + " 'SOURCE': 'lib',\n", " },\n", " annotation=False,\n", ")\n", @@ -25080,29 +25079,29 @@ "metadata": {}, "outputs": [], "source": [ - "TRAIN[\"peaks\"] = TRAIN[\"peaks\"].apply(\n", + "TRAIN['peaks'] = TRAIN['peaks'].apply(\n", " lambda x: (\n", - " {k: v for k, v in x.items() if k != \"mz_int\"} if isinstance(x, dict) else x\n", + " {k: v for k, v in x.items() if k != 'mz_int'} if isinstance(x, dict) else x\n", " )\n", ")\n", "\n", "TRAIN[\n", " [\n", - " \"Name\",\n", - " \"SMILES\",\n", - " \"InChIKey\",\n", - " 
\"group_id\",\n", - " \"datasplit\",\n", - " \"Precursor_type\",\n", - " \"Instrument_type\",\n", - " \"CE\",\n", - " \"lib\",\n", - " \"peaks\",\n", - " \"summary\",\n", - " \"RETENTIONTIME\",\n", - " \"CCS\",\n", + " 'Name',\n", + " 'SMILES',\n", + " 'InChIKey',\n", + " 'group_id',\n", + " 'datasplit',\n", + " 'Precursor_type',\n", + " 'Instrument_type',\n", + " 'CE',\n", + " 'lib',\n", + " 'peaks',\n", + " 'summary',\n", + " 'RETENTIONTIME',\n", + " 'CCS',\n", " ]\n", - "].to_csv(f\"{home}/data/archive/fiora_source_data/training/dataframe.csv\")" + "].to_csv(f'{home}/data/archive/fiora_source_data/training/dataframe.csv')" ] } ], diff --git a/notebooks/train_model.ipynb b/notebooks/train_model.ipynb index 90ce3b6..aa1980a 100644 --- a/notebooks/train_model.ipynb +++ b/notebooks/train_model.ipynb @@ -28,6 +28,7 @@ ], "source": [ "import sys\n", + "\n", "import torch\n", "\n", "seed = 42\n", @@ -36,28 +37,29 @@ "torch.set_printoptions(precision=2, sci_mode=False)\n", "\n", "\n", - "import pandas as pd\n", - "import numpy as np\n", "import ast\n", "import copy\n", "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", "# Load Modules\n", - "sys.path.append(\"..\")\n", + "sys.path.append('..')\n", "from os.path import expanduser\n", "\n", - "home = expanduser(\"~\")\n", - "from fiora.MOL.constants import DEFAULT_PPM, PPM, DEFAULT_MODES\n", - "from fiora.IO.LibraryLoader import LibraryLoader\n", - "from fiora.MOL.FragmentationTree import FragmentationTree\n", - "import fiora.visualization.spectrum_visualizer as sv\n", - "\n", - "from sklearn.metrics import r2_score\n", + "home = expanduser('~')\n", "import scipy\n", "from rdkit import RDLogger\n", + "from sklearn.metrics import r2_score\n", "\n", - "RDLogger.DisableLog(\"rdApp.*\")\n", + "import fiora.visualization.spectrum_visualizer as sv\n", + "from fiora.IO.LibraryLoader import LibraryLoader\n", + "from fiora.MOL.constants import DEFAULT_MODES, DEFAULT_PPM, PPM\n", + "from fiora.MOL.FragmentationTree import FragmentationTree\n", + "\n", + "RDLogger.DisableLog('rdApp.*')\n", "\n", - "print(f\"Working with Python {sys.version}\")" + "print(f'Working with Python {sys.version}')" ] }, { @@ -84,13 +86,13 @@ "source": [ "from typing import Literal\n", "\n", - "lib: Literal[\"NIST\", \"MSDIAL\", \"NIST/MSDIAL\", \"MSnLib\"] = \"MSnLib\" # \"MSnLib\"\n", - "print(f\"Preparing {lib} library\")\n", + "lib: Literal['NIST', 'MSDIAL', 'NIST/MSDIAL', 'MSnLib'] = 'MSnLib' # \"MSnLib\"\n", + "print(f'Preparing {lib} library')\n", "\n", "debug_mode = False # Default: False\n", "if debug_mode:\n", " print(\n", - " \"+++ This is a test run (debug mode) with a small subset of data points. Results are not representative. +++\"\n", + " '+++ This is a test run (debug mode) with a small subset of data points. Results are not representative. 
+++'\n", " )" ] }, @@ -102,14 +104,14 @@ "source": [ "# key map to read metadata from pandas DataFrame\n", "metadata_key_map = {\n", - " \"name\": \"Name\",\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"ionization\": \"Ionization\",\n", - " \"precursor_mz\": \"PrecursorMZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", - " \"ccs\": \"CCS\",\n", + " 'name': 'Name',\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'ionization': 'Ionization',\n", + " 'precursor_mz': 'PrecursorMZ',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'RETENTIONTIME',\n", + " 'ccs': 'CCS',\n", "}\n", "\n", "\n", @@ -119,14 +121,14 @@ "\n", "\n", "def load_training_data():\n", - " if \"NIST\" in lib or \"MSDIAL\" in lib:\n", - " data_path: str = f\"{home}/data/metabolites/preprocessed/datasplits_Jan24.csv\"\n", - " elif lib == \"MSnLib\":\n", + " if 'NIST' in lib or 'MSDIAL' in lib:\n", + " data_path: str = f'{home}/data/metabolites/preprocessed/datasplits_Jan24.csv'\n", + " elif lib == 'MSnLib':\n", " data_path: str = (\n", - " f\"{home}/data/metabolites/preprocessed/datasplits_msnlib_v7_Sep25.csv\"\n", + " f'{home}/data/metabolites/preprocessed/datasplits_msnlib_v7_Sep25.csv'\n", " )\n", " else:\n", - " raise NameError(f\"Unknown library selected {lib=}.\")\n", + " raise NameError(f'Unknown library selected {lib=}.')\n", " L = LibraryLoader()\n", " df = L.load_from_csv(data_path)\n", " return df\n", @@ -135,12 +137,12 @@ "df = load_training_data()\n", "\n", "# Restore dictionary values\n", - "dict_columns = [\"peaks\", \"summary\"]\n", + "dict_columns = ['peaks', 'summary']\n", "for col in dict_columns:\n", - " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace(\"nan\", \"None\")))\n", + " df[col] = df[col].apply(lambda x: ast.literal_eval(x.replace('nan', 'None')))\n", " # df[col] = df[col].apply(ast.literal_eval)\n", "\n", - "df[\"group_id\"] = df[\"group_id\"].astype(int)" + "df['group_id'] = df['group_id'].astype(int)" ] }, { @@ -159,11 +161,10 @@ "metadata": {}, "outputs": [], "source": [ - "from fiora.MOL.Metabolite import Metabolite\n", "from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder\n", "from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder\n", "from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder\n", - "\n", + "from fiora.MOL.Metabolite import Metabolite\n", "\n", "CE_upper_limit = 100.0\n", "weight_upper_limit = 1000.0\n", @@ -174,51 +175,51 @@ " # df = df.iloc[5000:20000,:]\n", "\n", "overwrite_setup_features = None\n", - "if lib == \"MSnLib\":\n", + "if lib == 'MSnLib':\n", " overwrite_setup_features = {\n", - " \"instrument\": [\"HCD\"],\n", - " \"precursor_mode\": [\"[M+H]+\", \"[M-H]-\", \"[M]+\", \"[M]-\"],\n", + " 'instrument': ['HCD'],\n", + " 'precursor_mode': ['[M+H]+', '[M-H]-', '[M]+', '[M]-'],\n", " }\n", "\n", "\n", - "df[\"Metabolite\"] = df[\"SMILES\"].apply(Metabolite)\n", - "df[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df['Metabolite'] = df['SMILES'].apply(Metabolite)\n", + "df['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "node_encoder = AtomFeatureEncoder(feature_list=[\"symbol\", \"num_hydrogen\", \"ring_type\"])\n", - "bond_encoder = BondFeatureEncoder(feature_list=[\"bond_type\", \"ring_type\"])\n", + "node_encoder = AtomFeatureEncoder(feature_list=['symbol', 'num_hydrogen', 'ring_type'])\n", + "bond_encoder = 
BondFeatureEncoder(feature_list=['bond_type', 'ring_type'])\n", "covariate_encoder = CovariateFeatureEncoder(\n", " feature_list=[\n", - " \"collision_energy\",\n", - " \"molecular_weight\",\n", - " \"precursor_mode\",\n", - " \"instrument\",\n", - " \"element_composition\",\n", + " 'collision_energy',\n", + " 'molecular_weight',\n", + " 'precursor_mode',\n", + " 'instrument',\n", + " 'element_composition',\n", " ],\n", " sets_overwrite=overwrite_setup_features,\n", ")\n", "rt_encoder = CovariateFeatureEncoder(\n", " feature_list=[\n", - " \"molecular_weight\",\n", - " \"precursor_mode\",\n", - " \"instrument\",\n", - " \"element_composition\",\n", + " 'molecular_weight',\n", + " 'precursor_mode',\n", + " 'instrument',\n", + " 'element_composition',\n", " ],\n", " sets_overwrite=overwrite_setup_features,\n", ")\n", "\n", - "covariate_encoder.normalize_features[\"collision_energy\"][\"max\"] = CE_upper_limit\n", - "covariate_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", - "rt_encoder.normalize_features[\"molecular_weight\"][\"max\"] = weight_upper_limit\n", + "covariate_encoder.normalize_features['collision_energy']['max'] = CE_upper_limit\n", + "covariate_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", + "rt_encoder.normalize_features['molecular_weight']['max'] = weight_upper_limit\n", "\n", - "df[\"Metabolite\"].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", - "df.apply(lambda x: x[\"Metabolite\"].set_id(x[\"group_id\"]), axis=1)\n", + "df['Metabolite'].apply(lambda x: x.compute_graph_attributes(node_encoder, bond_encoder))\n", + "df.apply(lambda x: x['Metabolite'].set_id(x['group_id']), axis=1)\n", "\n", "# df[\"summary\"] = df.apply(lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1)\n", "df.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder, rt_encoder),\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], covariate_encoder, rt_encoder),\n", " axis=1,\n", ")\n", - "_ = df.apply(lambda x: x[\"Metabolite\"].set_loss_weight(x[\"loss_weight\"]), axis=1)" + "_ = df.apply(lambda x: x['Metabolite'].set_loss_weight(x['loss_weight']), axis=1)" ] }, { @@ -230,7 +231,7 @@ "from fiora.MOL.MetaboliteIndex import MetaboliteIndex\n", "\n", "mindex: MetaboliteIndex = MetaboliteIndex()\n", - "mindex.index_metabolites(df[\"Metabolite\"])" + "mindex.index_metabolites(df['Metabolite'])" ] }, { @@ -249,9 +250,9 @@ "source": [ "mindex.create_fragmentation_trees()\n", "list_of_mismatched_ids = mindex.add_fragmentation_trees_to_metabolite_list(\n", - " df[\"Metabolite\"], graph_mismatch_policy=\"recompute\"\n", + " df['Metabolite'], graph_mismatch_policy='recompute'\n", ")\n", - "print(f\"Total number of recomputed trees: {len(list_of_mismatched_ids)}\")" + "print(f'Total number of recomputed trees: {len(list_of_mismatched_ids)}')" ] }, { @@ -262,8 +263,8 @@ "source": [ "# df[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", "_ = df.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=x[\"ppm_peak_tolerance\"]\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=x['ppm_peak_tolerance']\n", " ),\n", " axis=1,\n", ")" @@ -283,11 +284,11 @@ } ], "source": [ - "df[\"num_peak_matches\"] = df[\"Metabolite\"].apply(\n", - " lambda x: x.match_stats[\"num_peak_matches\"]\n", + 
"df['num_peak_matches'] = df['Metabolite'].apply(\n", + " lambda x: x.match_stats['num_peak_matches']\n", ")\n", - "print(sum(df[\"num_peak_matches\"] < 1))\n", - "df = df[df[\"num_peak_matches\"] >= 2]" + "print(sum(df['num_peak_matches'] < 1))\n", + "df = df[df['num_peak_matches'] >= 2]" ] }, { @@ -303,18 +304,18 @@ "metadata": {}, "outputs": [], "source": [ - "casmi16_path = f\"{home}/data/metabolites/CASMI_2016/casmi16_withCCS.csv\"\n", - "casmi22_path = f\"{home}/data/metabolites/CASMI_2022/casmi22_withCCS.csv\"\n", + "casmi16_path = f'{home}/data/metabolites/CASMI_2016/casmi16_withCCS.csv'\n", + "casmi22_path = f'{home}/data/metabolites/CASMI_2022/casmi22_withCCS.csv'\n", "\n", "df_cas = pd.read_csv(casmi16_path, index_col=[0], low_memory=False)\n", "df_cas22 = pd.read_csv(casmi22_path, index_col=[0], low_memory=False)\n", "\n", "# Restore dictionary values\n", - "dict_columns = [\"peaks\", \"Candidates\"]\n", + "dict_columns = ['peaks', 'Candidates']\n", "for col in dict_columns:\n", " df_cas[col] = df_cas[col].apply(ast.literal_eval)\n", "\n", - "df_cas22[\"peaks\"] = df_cas22[\"peaks\"].apply(ast.literal_eval)" + "df_cas22['peaks'] = df_cas22['peaks'].apply(ast.literal_eval)" ] }, { @@ -343,11 +344,11 @@ "\n", "if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", - " dev = \"cuda:0\"\n", + " dev = 'cuda:0'\n", "else:\n", - " dev = \"cpu\"\n", + " dev = 'cpu'\n", "\n", - "print(f\"Running on device: {dev}\")" + "print(f'Running on device: {dev}')" ] }, { @@ -376,10 +377,10 @@ } ], "source": [ - "print(df.groupby(\"datasplit\")[\"group_id\"].unique().apply(len))\n", + "print(df.groupby('datasplit')['group_id'].unique().apply(len))\n", "\n", - "df_test = df[df[\"datasplit\"] == \"test\"]\n", - "df_train = df[df[\"datasplit\"].isin([\"training\", \"validation\"])]" + "df_test = df[df['datasplit'] == 'test']\n", + "df_train = df[df['datasplit'].isin(['training', 'validation'])]" ] }, { @@ -396,8 +397,8 @@ } ], "source": [ - "geo_data = df_train[\"Metabolite\"].apply(lambda x: x.as_geometric_data().to(dev)).values\n", - "print(f\"Prepared training/validation with {len(geo_data)} data points\")" + "geo_data = df_train['Metabolite'].apply(lambda x: x.as_geometric_data().to(dev)).values\n", + "print(f'Prepared training/validation with {len(geo_data)} data points')" ] }, { @@ -415,50 +416,50 @@ "outputs": [], "source": [ "model_params = {\n", - " \"param_tag\": \"default\",\n", + " 'param_tag': 'default',\n", " # GNN parameters\n", - " \"gnn_type\": \"RGCNConv\",\n", - " \"depth\": 10, # 8\n", - " \"hidden_dimension\": 300, # 300\n", - " \"residual_connections\": False,\n", - " \"layer_stacking\": True, # Avoid residual connections and layer stacking at the same time\n", - " \"embedding_aggregation\": \"concat\",\n", - " \"embedding_dimension\": 300, # 300,\n", - " \"subgraph_features\": True,\n", - " \"pooling_func\": \"max\", # max or avg\n", - " \"layer_norm\": True,\n", + " 'gnn_type': 'RGCNConv',\n", + " 'depth': 10, # 8\n", + " 'hidden_dimension': 300, # 300\n", + " 'residual_connections': False,\n", + " 'layer_stacking': True, # Avoid residual connections and layer stacking at the same time\n", + " 'embedding_aggregation': 'concat',\n", + " 'embedding_dimension': 300, # 300,\n", + " 'subgraph_features': True,\n", + " 'pooling_func': 'max', # max or avg\n", + " 'layer_norm': True,\n", " # Dense layers\n", - " \"dense_layers\": 2, # 2 # Number of \"hidden\" dense layers, an additional output layer is always added\n", - " \"dense_dim\": 500, # Set to None (then dense dim 
defaults to GNN output dimension (very large if layer stacking is active))\n", + "    'dense_layers': 2,  # 2 # Number of \"hidden\" dense layers, an additional output layer is always added\n", + "    'dense_dim': 500,  # Set to None (then dense dim defaults to GNN output dimension (very large if layer stacking is active))\n", "    # Dropout\n", - "    \"input_dropout\": 0.25,  # 0.2,\n", - "    \"latent_dropout\": 0.25,  # 0.1,\n", + "    'input_dropout': 0.25,  # 0.2,\n", + "    'latent_dropout': 0.25,  # 0.1,\n", "    # Dimensions\n", - "    \"node_feature_layout\": node_encoder.feature_numbers,\n", - "    \"edge_feature_layout\": bond_encoder.feature_numbers,\n", - "    \"static_feature_dimension\": geo_data[0][\"static_edge_features\"].shape[1],\n", - "    \"static_rt_feature_dimension\": geo_data[0][\"static_rt_features\"].shape[1],\n", - "    \"output_dimension\": len(DEFAULT_MODES) * 2,  # per edge\n", + "    'node_feature_layout': node_encoder.feature_numbers,\n", + "    'edge_feature_layout': bond_encoder.feature_numbers,\n", + "    'static_feature_dimension': geo_data[0]['static_edge_features'].shape[1],\n", + "    'static_rt_feature_dimension': geo_data[0]['static_rt_features'].shape[1],\n", + "    'output_dimension': len(DEFAULT_MODES) * 2,  # per edge\n", "    # Keep track of how features are encoded\n", - "    \"atom_features\": node_encoder.feature_list,\n", - "    \"atom_features\": bond_encoder.feature_list,\n", - "    \"setup_features\": covariate_encoder.feature_list,\n", - "    \"setup_features_categorical_set\": covariate_encoder.categorical_sets,\n", - "    \"rt_features\": rt_encoder.feature_list,\n", + "    'atom_features': node_encoder.feature_list,\n", + "    'bond_features': bond_encoder.feature_list,\n", + "    'setup_features': covariate_encoder.feature_list,\n", + "    'setup_features_categorical_set': covariate_encoder.categorical_sets,\n", + "    'rt_features': rt_encoder.feature_list,\n", "    # Set default flags (May be overwritten below)\n", - "    \"prepare_additional_layers\": False,\n", - "    \"rt_supported\": False,\n", - "    \"ccs_supported\": False,\n", - "    \"version\": \"x.x.x\",\n", + "    'prepare_additional_layers': False,\n", + "    'rt_supported': False,\n", + "    'ccs_supported': False,\n", + "    'version': 'x.x.x',\n", "}\n", "training_params = {\n", - "    \"epochs\": 300 if not debug_mode else 10,\n", - "    \"batch_size\": 32,  # 256, # 256\n", "    #'train_val_split': 0.90,\n", - "    \"learning_rate\": 2e-4,  # 4e-4,\n", - "    \"weight_decay\": 1e-5,  # 1e-4,\n", - "    \"with_RT\": False,  # Turn off RT/CCS for initial trainings round\n", - "    \"with_CCS\": False,\n", + "    'epochs': 300 if not debug_mode else 10,\n", + "    'batch_size': 32,  # 256, # 256\n", "    #'train_val_split': 0.90,\n", + "    'learning_rate': 2e-4,  # 4e-4,\n", + "    'weight_decay': 1e-5,  # 1e-4,\n", + "    'with_RT': False,  # Turn off RT/CCS for initial training round\n", + "    'with_CCS': False,\n", "}" ] }, @@ -481,7 +482,7 @@ "model_snapshot = FioraModel(model_params)\n", "# Print num of parameters of model\n", "num_params = sum(p.numel() for p in model_snapshot.parameters() if p.requires_grad)\n", - "print(f\"Number of trainable parameters: {num_params:,}\")" + "print(f'Number of trainable parameters: {num_params:,}')" ] }, { @@ -512,9 +513,6 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "\n", - "\n", "# Subsample training data\n", "def subsample_keys(train_keys, val_keys, down_to_fraction: float):\n", "    train_sample = np.random.choice(\n", @@ -532,29 +530,30 @@ "metadata": {}, "outputs": [], "source": [ + "from sklearn.model_selection import train_test_split\n", + "\n", "from fiora.GNN.fabric_training import 
train_fabric_loop\n", "from fiora.GNN.FioraModel import FioraModel\n", "from fiora.GNN.Losses import (\n", - " WeightedMSELoss,\n", - " WeightedMSEMetric,\n", - " WeightedMAELoss,\n", - " WeightedMAEMetric,\n", " GraphwiseKLLoss,\n", " GraphwiseKLLossMetric,\n", + " WeightedMAELoss,\n", + " WeightedMAEMetric,\n", + " WeightedMSELoss,\n", + " WeightedMSEMetric,\n", ")\n", "from fiora.MS.SimulationFramework import SimulationFramework\n", - "from sklearn.model_selection import train_test_split\n", "\n", "fiora = SimulationFramework(None, dev=dev)\n", "# fiora = SimulationFramework(None, dev=dev, with_RT=training_params[\"with_RT\"], with_CCS=training_params[\"with_CCS\"])\n", - "np.seterr(invalid=\"ignore\")\n", - "tag = \"training\"\n", + "np.seterr(invalid='ignore')\n", + "tag = 'training'\n", "val_interval = 1\n", "metric_dict = {\n", - " \"mse\": GraphwiseKLLossMetric\n", + " 'mse': GraphwiseKLLossMetric\n", "} # reduction=\"mean\" by default #WeightedMSEMetric\n", "# loss_fn = WeightedMSELoss() # WeightedMSELoss()\n", - "loss_fn = GraphwiseKLLoss(reduction=\"mean\")\n", + "loss_fn = GraphwiseKLLoss(reduction='mean')\n", "all_together = False\n", "down_sample = False\n", "\n", @@ -573,7 +572,7 @@ " train_set = set(int(x) for x in train_keys)\n", " val_set = set(int(x) for x in val_keys)\n", " else:\n", - " group_ids = np.array([int(getattr(x, \"group_id\")) for x in geo_data])\n", + " group_ids = np.array([int(getattr(x, 'group_id')) for x in geo_data])\n", " keys = np.unique(group_ids)\n", " tr, va = train_test_split(\n", " keys, test_size=1 - train_val_split, random_state=seed\n", @@ -581,8 +580,8 @@ " train_set = set(int(x) for x in tr)\n", " val_set = set(int(x) for x in va)\n", "\n", - " train_data = [x for x in geo_data if int(getattr(x, \"group_id\")) in train_set]\n", - " val_data = [x for x in geo_data if int(getattr(x, \"group_id\")) in val_set]\n", + " train_data = [x for x in geo_data if int(getattr(x, 'group_id')) in train_set]\n", + " val_data = [x for x in geo_data if int(getattr(x, 'group_id')) in val_set]\n", " return train_data, val_data\n", "\n", "\n", @@ -590,7 +589,7 @@ " continue_with_model=None,\n", " model_params=model_params,\n", " training_params=training_params,\n", - " tag=\"\",\n", + " tag='',\n", "):\n", " if continue_with_model:\n", " model = continue_with_model.to(dev)\n", @@ -598,21 +597,21 @@ " model = FioraModel(model_params).to(dev)\n", "\n", " # y_label = 'compiled_probsSQRT' # y_label = 'compiled_probsALL'\n", - " y_label = \"compiled_probsALL\"\n", + " y_label = 'compiled_probsALL'\n", " optimizer = torch.optim.Adam(\n", " model.parameters(),\n", - " lr=training_params[\"learning_rate\"],\n", - " weight_decay=training_params[\"weight_decay\"],\n", + " lr=training_params['learning_rate'],\n", + " weight_decay=training_params['weight_decay'],\n", " )\n", - " save_path = f\"../../checkpoint_{tag}.best.pt\"\n", + " save_path = f'../../checkpoint_{tag}.best.pt'\n", "\n", " if all_together:\n", " train_data, val_data = geo_data, []\n", " scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)\n", " else:\n", " train_keys, val_keys = (\n", - " df[df[\"datasplit\"] == \"training\"][\"group_id\"].unique(),\n", - " df[df[\"datasplit\"] == \"validation\"][\"group_id\"].unique(),\n", + " df[df['datasplit'] == 'training']['group_id'].unique(),\n", + " df[df['datasplit'] == 'validation']['group_id'].unique(),\n", " )\n", " if down_sample:\n", " train_fraction = 0.10\n", @@ -620,13 +619,13 @@ " train_keys, val_keys, train_fraction\n", " ) # 
Downsample training data for test\n", " print(\n", - " f\"Sample down to {train_fraction * 100}% with {len(train_keys)} training and {len(val_keys)} validation compounds \"\n", + " f'Sample down to {train_fraction * 100}% with {len(train_keys)} training and {len(val_keys)} validation compounds '\n", " )\n", " train_data, val_data = split_geo_by_group(\n", " geo_data, train_keys=train_keys, val_keys=val_keys, seed=seed\n", " )\n", " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n", - " optimizer, patience=8, factor=0.5, mode=\"min\"\n", + " optimizer, patience=8, factor=0.5, mode='min'\n", " )\n", "\n", " checkpoints, history = train_fabric_loop(\n", @@ -637,20 +636,20 @@ " metric_dict=metric_dict,\n", " y_label=y_label,\n", " device=dev,\n", - " batch_size=training_params[\"batch_size\"],\n", + " batch_size=training_params['batch_size'],\n", " num_workers=0,\n", - " epochs=training_params[\"epochs\"],\n", + " epochs=training_params['epochs'],\n", " val_every=val_interval,\n", - " learning_rate=training_params[\"learning_rate\"],\n", - " weight_decay=training_params[\"weight_decay\"],\n", - " scheduler_name=\"none\",\n", + " learning_rate=training_params['learning_rate'],\n", + " weight_decay=training_params['weight_decay'],\n", + " scheduler_name='none',\n", " scheduler_patience=8,\n", " scheduler_factor=0.5,\n", - " with_rt=training_params[\"with_RT\"],\n", - " with_ccs=training_params[\"with_CCS\"],\n", + " with_rt=training_params['with_RT'],\n", + " with_ccs=training_params['with_CCS'],\n", " rt_metric=False,\n", " use_validation_mask=False,\n", - " validation_mask_name=\"validation_mask\",\n", + " validation_mask_name='validation_mask',\n", " output_path=save_path,\n", " optimizer=optimizer,\n", " scheduler=scheduler,\n", @@ -664,7 +663,7 @@ " return fiora.simulate_all(DF, model)\n", "\n", "\n", - "def test_model(model, DF, score=\"spectral_sqrt_cosine\", return_df=False):\n", + "def test_model(model, DF, score='spectral_sqrt_cosine', return_df=False):\n", " dft = simulate_all(model, DF)\n", "\n", " if return_df:\n", @@ -683,33 +682,33 @@ " base_model_params,\n", " base_training_params,\n", " store_models: bool = False,\n", - " prefix=\"config\",\n", + " prefix='config',\n", "):\n", " results = []\n", "\n", " for i, param_override in enumerate(param_grid):\n", - " print(f\"Running configuration {i + 1}/{len(param_grid)}...\")\n", + " print(f'Running configuration {i + 1}/{len(param_grid)}...')\n", "\n", " # Update base parameters with overrides\n", " model_params = base_model_params.copy()\n", " training_params = base_training_params.copy()\n", "\n", - " model_params.update(param_override.get(\"model_params\", {}))\n", - " training_params.update(param_override.get(\"training_params\", {}))\n", + " model_params.update(param_override.get('model_params', {}))\n", + " training_params.update(param_override.get('training_params', {}))\n", "\n", " # Train the model with the updated parameters\n", " try:\n", " model, checkpoints, history = train_new_model(\n", " model_params=model_params,\n", " training_params=training_params,\n", - " tag=f\"{prefix}_{i + 1}\",\n", + " tag=f'{prefix}_{i + 1}',\n", " )\n", " results.append(\n", " {\n", - " \"config\": param_override,\n", - " \"model\": model if store_models else None,\n", - " \"checkpoints\": checkpoints,\n", - " \"history\": history,\n", + " 'config': param_override,\n", + " 'model': model if store_models else None,\n", + " 'checkpoints': checkpoints,\n", + " 'history': history,\n", " }\n", " )\n", "\n", @@ -718,8 +717,8 @@ " 
torch.cuda.empty_cache()\n", "\n", " except Exception as e:\n", - " print(f\"Error in configuration {i + 1}: {e}\")\n", - " results.append({\"config\": param_override, \"error\": str(e)})\n", + " print(f'Error in configuration {i + 1}: {e}')\n", + " results.append({'config': param_override, 'error': str(e)})\n", "\n", " return results" ] @@ -1520,13 +1519,13 @@ "source": [ "import torch.multiprocessing as mp\n", "\n", - "mp.set_start_method(\"spawn\", force=True)\n", + "mp.set_start_method('spawn', force=True)\n", "GRID_SEARCH = True\n", "if GRID_SEARCH:\n", " param_grid = [\n", " # Run three repeats of the default model\n", - " {\"model_params\": {}, \"training_params\": {}},\n", - " {\"model_params\": {}, \"training_params\": {}},\n", + " {'model_params': {}, 'training_params': {}},\n", + " {'model_params': {}, 'training_params': {}},\n", " ]\n", "\n", " grid_results = grid_search(\n", @@ -1535,11 +1534,11 @@ "\n", " # Analyze results\n", " for result in grid_results:\n", - " print(result[\"config\"])\n", - " if \"error\" in result:\n", - " print(f\"Error: {result['error']}\")\n", + " print(result['config'])\n", + " if 'error' in result:\n", + " print(f'Error: {result[\"error\"]}')\n", " else:\n", - " print(f\"Checkpoints: {result['checkpoints']}\")" + " print(f'Checkpoints: {result[\"checkpoints\"]}')" ] }, { @@ -1548,12 +1547,12 @@ "metadata": {}, "outputs": [], "source": [ - "if \"model\" in locals():\n", + "if 'model' in locals():\n", " del model\n", " torch.cuda.empty_cache()\n", "\n", "if not GRID_SEARCH:\n", - " print(f\"Training model\")\n", + " print('Training model')\n", " model, checkpoints, history = train_new_model() # continue_with_model=model)" ] }, @@ -1573,23 +1572,21 @@ } ], "source": [ - "import copy\n", - "\n", "best_result = None\n", "if not GRID_SEARCH:\n", " print(checkpoints)\n", " model_at_last_epoch = copy.deepcopy(model)\n", "\n", "else:\n", - " best_result = min(grid_results, key=lambda x: x[\"checkpoints\"][\"sqrt_val_loss\"])\n", - " model = best_result[\"model\"]\n", - " checkpoints = best_result[\"checkpoints\"]\n", - " history = best_result[\"history\"]\n", + " best_result = min(grid_results, key=lambda x: x['checkpoints']['sqrt_val_loss'])\n", + " model = best_result['model']\n", + " checkpoints = best_result['checkpoints']\n", + " history = best_result['history']\n", " print(\n", - " f\"Best model found with val_sqrt_error: {best_result['checkpoints']['sqrt_val_loss']}\"\n", + " f'Best model found with val_sqrt_error: {best_result[\"checkpoints\"][\"sqrt_val_loss\"]}'\n", " )\n", - " print(\"Parameter overrides for the best model:\")\n", - " print(best_result[\"config\"])" + " print('Parameter overrides for the best model:')\n", + " print(best_result['config'])" ] }, { @@ -1617,78 +1614,78 @@ } ], "source": [ - "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", "\n", "# Convert numpy arrays to scalars if they are single-element arrays\n", - "history[\"train_error\"] = [\n", + "history['train_error'] = [\n", " error.item() if isinstance(error, np.ndarray) and error.size == 1 else error\n", - " for error in history[\"train_error\"]\n", + " for error in history['train_error']\n", "]\n", - "history[\"val_error\"] = [\n", + "history['val_error'] = [\n", " error.item() if isinstance(error, np.ndarray) and error.size == 1 else error\n", - " for error in history[\"val_error\"]\n", + " for error in history['val_error']\n", "]\n", - "history[\"lr\"] = [\n", + "history['lr'] = [\n", " lr.item() if isinstance(lr, 
np.ndarray) and lr.size == 1 else lr\n", - " for lr in history[\"lr\"]\n", + " for lr in history['lr']\n", "]\n", "\n", "# Create a DataFrame from the tracker dictionary\n", "tracker_df = pd.DataFrame(\n", " {\n", - " \"epoch\": history[\"epoch\"],\n", - " \"train_rmse\": history[\"sqrt_train_error\"],\n", - " \"val_rmse\": history[\"sqrt_val_error\"],\n", - " \"lr\": history[\"lr\"],\n", + " 'epoch': history['epoch'],\n", + " 'train_rmse': history['sqrt_train_error'],\n", + " 'val_rmse': history['sqrt_val_error'],\n", + " 'lr': history['lr'],\n", " }\n", ")\n", "\n", "# Plot the training and validation loss\n", "plt.figure(figsize=(10, 5))\n", "sns.lineplot(\n", - " data=tracker_df, x=\"epoch\", y=\"train_rmse\", label=\"Training RMSE\", color=\"blue\"\n", + " data=tracker_df, x='epoch', y='train_rmse', label='Training RMSE', color='blue'\n", ")\n", "sns.lineplot(\n", - " data=tracker_df, x=\"epoch\", y=\"val_rmse\", label=\"Validation RMSE\", color=\"orange\"\n", + " data=tracker_df, x='epoch', y='val_rmse', label='Validation RMSE', color='orange'\n", ")\n", "\n", "# Highlight the epochs where the learning rate changes\n", "previous_lr = None\n", "for _, row in tracker_df.iterrows():\n", - " current_lr = row[\"lr\"]\n", + " current_lr = row['lr']\n", " if current_lr != previous_lr:\n", - " epoch = row[\"epoch\"]\n", - " val_loss_at_epoch = row[\"val_rmse\"]\n", + " epoch = row['epoch']\n", + " val_loss_at_epoch = row['val_rmse']\n", " plt.scatter(\n", " epoch,\n", " val_loss_at_epoch + 0.0001,\n", - " color=\"black\",\n", - " marker=\"v\",\n", - " label=\"LR Change\" if previous_lr is None else \"\",\n", + " color='black',\n", + " marker='v',\n", + " label='LR Change' if previous_lr is None else '',\n", " )\n", " plt.text(\n", " epoch,\n", " val_loss_at_epoch + 0.0002,\n", - " f\"LR: {current_lr:1.0e}\",\n", - " color=\"black\",\n", - " ha=\"center\",\n", + " f'LR: {current_lr:1.0e}',\n", + " color='black',\n", + " ha='center',\n", " fontsize=8,\n", " )\n", " previous_lr = current_lr\n", "\n", - "plt.xlabel(\"Epoch\")\n", - "plt.ylabel(\"RMSE\")\n", + "plt.xlabel('Epoch')\n", + "plt.ylabel('RMSE')\n", "# plt.ylim(0, tracker_df[\"val_rmse\"].max() + 0.004)\n", - "plt.title(\"Training and Validation Loss Over Epochs\")\n", + "plt.title('Training and Validation Loss Over Epochs')\n", "plt.legend()\n", "plt.show()\n", - "min_train_error = min(history[\"sqrt_train_error\"])\n", - "min_val_error = min(history[\"sqrt_val_error\"])\n", - "epoch_min_train_error = history[\"epoch\"][np.argmin(history[\"sqrt_train_error\"])]\n", - "epoch_min_val_error = history[\"epoch\"][np.argmin(history[\"sqrt_val_error\"])]\n", - "print(f\"Minimum Training RMSE: {min_train_error:.5f} (Epoch {epoch_min_train_error})\")\n", - "print(f\"Minimum Validation RMSE: {min_val_error:.5f} (Epoch {epoch_min_val_error})\")" + "min_train_error = min(history['sqrt_train_error'])\n", + "min_val_error = min(history['sqrt_val_error'])\n", + "epoch_min_train_error = history['epoch'][np.argmin(history['sqrt_train_error'])]\n", + "epoch_min_val_error = history['epoch'][np.argmin(history['sqrt_val_error'])]\n", + "print(f'Minimum Training RMSE: {min_train_error:.5f} (Epoch {epoch_min_train_error})')\n", + "print(f'Minimum Validation RMSE: {min_val_error:.5f} (Epoch {epoch_min_val_error})')" ] }, { @@ -1699,36 +1696,36 @@ "source": [ "from fiora.MOL.collision_energy import NCE_to_eV\n", "\n", - "df_cas[\"RETENTIONTIME\"] = df_cas[\"RTINSECONDS\"] / 60.0\n", - "df_cas[\"Metabolite\"] = df_cas[\"SMILES\"].apply(Metabolite)\n", - 
"df_cas[\"Metabolite\"].apply(lambda x: x.create_molecular_structure_graph())\n", + "df_cas['RETENTIONTIME'] = df_cas['RTINSECONDS'] / 60.0\n", + "df_cas['Metabolite'] = df_cas['SMILES'].apply(Metabolite)\n", + "df_cas['Metabolite'].apply(lambda x: x.create_molecular_structure_graph())\n", "\n", - "df_cas[\"Metabolite\"].apply(\n", + "df_cas['Metabolite'].apply(\n", " lambda x: x.compute_graph_attributes(node_encoder, bond_encoder)\n", ")\n", - "df_cas[\"CE\"] = 20.0 # actually stepped 20/35/50\n", - "df_cas[\"Instrument_type\"] = \"HCD\" # CHECK if correct Orbitrap\n", + "df_cas['CE'] = 20.0 # actually stepped 20/35/50\n", + "df_cas['Instrument_type'] = 'HCD' # CHECK if correct Orbitrap\n", "\n", "metadata_key_map16 = {\n", - " \"collision_energy\": \"CE\",\n", - " \"instrument\": \"Instrument_type\",\n", - " \"precursor_mz\": \"PRECURSOR_MZ\",\n", - " \"precursor_mode\": \"Precursor_type\",\n", - " \"retention_time\": \"RETENTIONTIME\",\n", + " 'collision_energy': 'CE',\n", + " 'instrument': 'Instrument_type',\n", + " 'precursor_mz': 'PRECURSOR_MZ',\n", + " 'precursor_mode': 'Precursor_type',\n", + " 'retention_time': 'RETENTIONTIME',\n", "}\n", "\n", - "df_cas[\"summary\"] = df_cas.apply(\n", + "df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", ")\n", "df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(x[\"summary\"], covariate_encoder), axis=1\n", + " lambda x: x['Metabolite'].add_metadata(x['summary'], covariate_encoder), axis=1\n", ")\n", "\n", "# Fragmentation\n", - "df_cas[\"Metabolite\"].apply(lambda x: x.fragment_MOL(depth=1))\n", + "df_cas['Metabolite'].apply(lambda x: x.fragment_MOL(depth=1))\n", "_ = df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].match_fragments_to_peaks(\n", - " x[\"peaks\"][\"mz\"], x[\"peaks\"][\"intensity\"], tolerance=100 * PPM\n", + " lambda x: x['Metabolite'].match_fragments_to_peaks(\n", + " x['peaks']['mz'], x['peaks']['intensity'], tolerance=100 * PPM\n", " ),\n", " axis=1,\n", ") # Optional: use mz_cut instead\n", @@ -1767,117 +1764,116 @@ "metadata": {}, "outputs": [], "source": [ - "from fiora.MOL.collision_energy import NCE_to_eV\n", + "from fiora.MS.ms_utility import merge_annotated_spectrum\n", "from fiora.MS.spectral_scores import (\n", + " reweighted_dot,\n", " spectral_cosine,\n", " spectral_reflection_cosine,\n", - " reweighted_dot,\n", ")\n", - "from fiora.MS.ms_utility import merge_annotated_spectrum\n", "\n", "\n", - "def test_cas16(model, df_cas=df_cas, score=\"merged_sqrt_cosine\", return_df=False):\n", + "def test_cas16(model, df_cas=df_cas, score='merged_sqrt_cosine', return_df=False):\n", "\n", - " df_cas[\"NCE\"] = 20.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " df_cas['NCE'] = 20.0 # actually stepped NCE 20/35/50\n", + " df_cas['CE'] = df_cas[['NCE', 'PRECURSOR_MZ']].apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['PRECURSOR_MZ']), axis=1\n", " )\n", - " df_cas[\"step1_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(\n", + " df_cas['step1_CE'] = df_cas['CE']\n", + " df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", " )\n", " df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(\n", - " x[\"summary\"], covariate_encoder, rt_encoder\n", + " lambda x: x['Metabolite'].add_metadata(\n", + " x['summary'], covariate_encoder, 
rt_encoder\n", " ),\n", " axis=1,\n", " )\n", - " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_20\")\n", + " df_cas = fiora.simulate_all(df_cas, model, suffix='_20')\n", "\n", - " df_cas[\"NCE\"] = 35.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " df_cas['NCE'] = 35.0 # actually stepped NCE 20/35/50\n", + " df_cas['CE'] = df_cas[['NCE', 'PRECURSOR_MZ']].apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['PRECURSOR_MZ']), axis=1\n", " )\n", - " df_cas[\"step2_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(\n", + " df_cas['step2_CE'] = df_cas['CE']\n", + " df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", " )\n", " df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(\n", - " x[\"summary\"], covariate_encoder, rt_encoder\n", + " lambda x: x['Metabolite'].add_metadata(\n", + " x['summary'], covariate_encoder, rt_encoder\n", " ),\n", " axis=1,\n", " )\n", - " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_35\")\n", + " df_cas = fiora.simulate_all(df_cas, model, suffix='_35')\n", "\n", - " df_cas[\"NCE\"] = 50.0 # actually stepped NCE 20/35/50\n", - " df_cas[\"CE\"] = df_cas[[\"NCE\", \"PRECURSOR_MZ\"]].apply(\n", - " lambda x: NCE_to_eV(x[\"NCE\"], x[\"PRECURSOR_MZ\"]), axis=1\n", + " df_cas['NCE'] = 50.0 # actually stepped NCE 20/35/50\n", + " df_cas['CE'] = df_cas[['NCE', 'PRECURSOR_MZ']].apply(\n", + " lambda x: NCE_to_eV(x['NCE'], x['PRECURSOR_MZ']), axis=1\n", " )\n", - " df_cas[\"step3_CE\"] = df_cas[\"CE\"]\n", - " df_cas[\"summary\"] = df_cas.apply(\n", + " df_cas['step3_CE'] = df_cas['CE']\n", + " df_cas['summary'] = df_cas.apply(\n", " lambda x: {key: x[name] for key, name in metadata_key_map16.items()}, axis=1\n", " )\n", " df_cas.apply(\n", - " lambda x: x[\"Metabolite\"].add_metadata(\n", - " x[\"summary\"], covariate_encoder, rt_encoder\n", + " lambda x: x['Metabolite'].add_metadata(\n", + " x['summary'], covariate_encoder, rt_encoder\n", " ),\n", " axis=1,\n", " )\n", - " df_cas = fiora.simulate_all(df_cas, model, suffix=\"_50\")\n", + " df_cas = fiora.simulate_all(df_cas, model, suffix='_50')\n", "\n", - " df_cas[\"avg_CE\"] = (\n", - " df_cas[\"step1_CE\"] + df_cas[\"step2_CE\"] + df_cas[\"step3_CE\"]\n", + " df_cas['avg_CE'] = (\n", + " df_cas['step1_CE'] + df_cas['step2_CE'] + df_cas['step3_CE']\n", " ) / 3\n", "\n", - " df_cas[\"merged_peaks\"] = df_cas.apply(\n", + " df_cas['merged_peaks'] = df_cas.apply(\n", " lambda x: merge_annotated_spectrum(\n", - " merge_annotated_spectrum(x[\"sim_peaks_20\"], x[\"sim_peaks_35\"]),\n", - " x[\"sim_peaks_50\"],\n", + " merge_annotated_spectrum(x['sim_peaks_20'], x['sim_peaks_35']),\n", + " x['sim_peaks_50'],\n", " ),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_cosine\"] = df_cas.apply(\n", - " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " df_cas['merged_cosine'] = df_cas.apply(\n", + " lambda x: spectral_cosine(x['peaks'], x['merged_peaks']), axis=1\n", " )\n", - " df_cas[\"merged_sqrt_cosine\"] = df_cas.apply(\n", - " lambda x: spectral_cosine(x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt),\n", + " df_cas['merged_sqrt_cosine'] = df_cas.apply(\n", + " lambda x: spectral_cosine(x['peaks'], x['merged_peaks'], transform=np.sqrt),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_sqrt_cosine_wo_prec\"] = df_cas.apply(\n", + " df_cas['merged_sqrt_cosine_wo_prec'] = 
df_cas.apply(\n", " lambda x: spectral_cosine(\n", - " x[\"peaks\"],\n", - " x[\"merged_peaks\"],\n", + " x['peaks'],\n", + " x['merged_peaks'],\n", " transform=np.sqrt,\n", - " remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(\n", - " x[\"Metabolite\"].metadata[\"precursor_mode\"]\n", + " remove_mz=x['Metabolite'].get_theoretical_precursor_mz(\n", + " x['Metabolite'].metadata['precursor_mode']\n", " ),\n", " ),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_refl_cosine\"] = df_cas.apply(\n", + " df_cas['merged_refl_cosine'] = df_cas.apply(\n", " lambda x: spectral_reflection_cosine(\n", - " x[\"peaks\"], x[\"merged_peaks\"], transform=np.sqrt\n", + " x['peaks'], x['merged_peaks'], transform=np.sqrt\n", " ),\n", " axis=1,\n", " )\n", - " df_cas[\"merged_steins\"] = df_cas.apply(\n", - " lambda x: reweighted_dot(x[\"peaks\"], x[\"merged_peaks\"]), axis=1\n", + " df_cas['merged_steins'] = df_cas.apply(\n", + " lambda x: reweighted_dot(x['peaks'], x['merged_peaks']), axis=1\n", " )\n", - " df_cas[\"spectral_sqrt_cosine\"] = df_cas[\n", - " \"merged_sqrt_cosine\"\n", + " df_cas['spectral_sqrt_cosine'] = df_cas[\n", + " 'merged_sqrt_cosine'\n", " ] # just remember it is merged\n", - " df_cas[\"spectral_sqrt_cosine_wo_prec\"] = df_cas[\n", - " \"merged_sqrt_cosine_wo_prec\"\n", + " df_cas['spectral_sqrt_cosine_wo_prec'] = df_cas[\n", + " 'merged_sqrt_cosine_wo_prec'\n", " ] # just remember it is merged\n", "\n", - " df_cas[\"coverage\"] = df_cas[\"Metabolite\"].apply(lambda x: x.match_stats[\"coverage\"])\n", - " if hasattr(model, \"rt_module\"):\n", - " df_cas[\"RT_pred\"] = df_cas[\"RT_pred_35\"]\n", - " df_cas[\"RT_dif\"] = df_cas[\"RT_dif_35\"]\n", - " if hasattr(model, \"ccs_module\"):\n", - " df_cas[\"CCS_pred\"] = df_cas[\"CCS_pred_35\"]\n", - " df_cas[\"library\"] = \"CASMI-16\"\n", + " df_cas['coverage'] = df_cas['Metabolite'].apply(lambda x: x.match_stats['coverage'])\n", + " if hasattr(model, 'rt_module'):\n", + " df_cas['RT_pred'] = df_cas['RT_pred_35']\n", + " df_cas['RT_dif'] = df_cas['RT_dif_35']\n", + " if hasattr(model, 'ccs_module'):\n", + " df_cas['CCS_pred'] = df_cas['CCS_pred_35']\n", + " df_cas['library'] = 'CASMI-16'\n", "\n", " if return_df:\n", " return df_cas\n", @@ -1892,11 +1888,11 @@ "outputs": [], "source": [ "model = (\n", - " FioraModel.load(checkpoints[\"file\"]).to(dev)\n", + " FioraModel.load(checkpoints['file']).to(dev)\n", " if not GRID_SEARCH\n", - " else FioraModel.load(best_result[\"checkpoints\"][\"file\"]).to(dev)\n", + " else FioraModel.load(best_result['checkpoints']['file']).to(dev)\n", ")\n", - "df_val = df_train[df_train[\"datasplit\"] == \"validation\"]\n", + "df_val = df_train[df_train['datasplit'] == 'validation']\n", "\n", "df_val = test_model(model, df_val, return_df=True)\n", "df_test = test_model(model, df_test, return_df=True)\n", @@ -1910,21 +1906,20 @@ "outputs": [], "source": [ "from fiora.MOL.constants import DEFAULT_DALTON\n", - "from fiora.MS.spectral_scores import spectral_cosine\n", "\n", "\n", "def construct_explained_peaks(df, tolerance):\n", " explained_peaks_list = []\n", "\n", " for _, row in df.iterrows():\n", - " peaks = row[\"peaks\"]\n", - " metabolite = row[\"Metabolite\"]\n", + " peaks = row['peaks']\n", + " metabolite = row['Metabolite']\n", " peak_matches = metabolite.peak_matches\n", "\n", " explained_mz = []\n", " explained_intensity = []\n", "\n", - " for mz, intensity in zip(peaks[\"mz\"], peaks[\"intensity\"]):\n", + " for mz, intensity in zip(peaks['mz'], peaks['intensity']):\n", " for matched_mz in 
peak_matches.keys():\n", " if abs(mz - matched_mz) <= tolerance:\n", " explained_mz.append(mz)\n", @@ -1932,54 +1927,54 @@ " break # Stop checking once a match is found\n", "\n", " explained_peaks_list.append(\n", - " {\"mz\": explained_mz, \"intensity\": explained_intensity}\n", + " {'mz': explained_mz, 'intensity': explained_intensity}\n", " )\n", "\n", - " df[\"explained_peaks\"] = explained_peaks_list\n", + " df['explained_peaks'] = explained_peaks_list\n", " return df\n", "\n", "\n", "df_test = construct_explained_peaks(df_test, DEFAULT_DALTON)\n", - "df_test[\"explained_sqrt_cosine\"] = df_test.apply(\n", - " lambda x: spectral_cosine(x[\"explained_peaks\"], x[\"peaks\"], transform=np.sqrt),\n", + "df_test['explained_sqrt_cosine'] = df_test.apply(\n", + " lambda x: spectral_cosine(x['explained_peaks'], x['peaks'], transform=np.sqrt),\n", " axis=1,\n", ")\n", - "df_test[\"explained_sqrt_cosine_wo_prec\"] = df_test.apply(\n", + "df_test['explained_sqrt_cosine_wo_prec'] = df_test.apply(\n", " lambda x: spectral_cosine(\n", - " x[\"explained_peaks\"],\n", - " x[\"peaks\"],\n", + " x['explained_peaks'],\n", + " x['peaks'],\n", " transform=np.sqrt,\n", - " remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(),\n", + " remove_mz=x['Metabolite'].get_theoretical_precursor_mz(),\n", " ),\n", " axis=1,\n", ")\n", "\n", "df_val = construct_explained_peaks(df_val, DEFAULT_DALTON)\n", - "df_val[\"explained_sqrt_cosine\"] = df_val.apply(\n", - " lambda x: spectral_cosine(x[\"explained_peaks\"], x[\"peaks\"], transform=np.sqrt),\n", + "df_val['explained_sqrt_cosine'] = df_val.apply(\n", + " lambda x: spectral_cosine(x['explained_peaks'], x['peaks'], transform=np.sqrt),\n", " axis=1,\n", ")\n", - "df_val[\"explained_sqrt_cosine_wo_prec\"] = df_val.apply(\n", + "df_val['explained_sqrt_cosine_wo_prec'] = df_val.apply(\n", " lambda x: spectral_cosine(\n", - " x[\"explained_peaks\"],\n", - " x[\"peaks\"],\n", + " x['explained_peaks'],\n", + " x['peaks'],\n", " transform=np.sqrt,\n", - " remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(),\n", + " remove_mz=x['Metabolite'].get_theoretical_precursor_mz(),\n", " ),\n", " axis=1,\n", ")\n", "\n", "df_cas16 = construct_explained_peaks(df_cas16, DEFAULT_DALTON)\n", - "df_cas16[\"explained_sqrt_cosine\"] = df_cas16.apply(\n", - " lambda x: spectral_cosine(x[\"explained_peaks\"], x[\"peaks\"], transform=np.sqrt),\n", + "df_cas16['explained_sqrt_cosine'] = df_cas16.apply(\n", + " lambda x: spectral_cosine(x['explained_peaks'], x['peaks'], transform=np.sqrt),\n", " axis=1,\n", ")\n", - "df_cas16[\"explained_sqrt_cosine_wo_prec\"] = df_cas16.apply(\n", + "df_cas16['explained_sqrt_cosine_wo_prec'] = df_cas16.apply(\n", " lambda x: spectral_cosine(\n", - " x[\"explained_peaks\"],\n", - " x[\"peaks\"],\n", + " x['explained_peaks'],\n", + " x['peaks'],\n", " transform=np.sqrt,\n", - " remove_mz=x[\"Metabolite\"].get_theoretical_precursor_mz(),\n", + " remove_mz=x['Metabolite'].get_theoretical_precursor_mz(),\n", " ),\n", " axis=1,\n", ")" @@ -2070,26 +2065,26 @@ "# Dataframe with Val, Test, CASMI-16 rows and fiora model median scores w/o precursor\n", "df_scores = pd.DataFrame(\n", " {\n", - " \"dataset\": [\"Validation\", \"Test\", \"CASMI-16\"],\n", - " \"spectral_sqrt_cosine\": [\n", - " avg_func(df_val[\"spectral_sqrt_cosine\"]),\n", - " avg_func(df_test[\"spectral_sqrt_cosine\"]),\n", - " avg_func(df_cas16[\"spectral_sqrt_cosine\"].fillna(0)),\n", + " 'dataset': ['Validation', 'Test', 'CASMI-16'],\n", + " 'spectral_sqrt_cosine': [\n", + " 
avg_func(df_val['spectral_sqrt_cosine']),\n", + " avg_func(df_test['spectral_sqrt_cosine']),\n", + " avg_func(df_cas16['spectral_sqrt_cosine'].fillna(0)),\n", " ],\n", - " \"explained_sqrt_cosine\": [\n", - " avg_func(df_val[\"explained_sqrt_cosine\"]),\n", - " avg_func(df_test[\"explained_sqrt_cosine\"]),\n", - " avg_func(df_cas16[\"explained_sqrt_cosine\"].fillna(0)),\n", + " 'explained_sqrt_cosine': [\n", + " avg_func(df_val['explained_sqrt_cosine']),\n", + " avg_func(df_test['explained_sqrt_cosine']),\n", + " avg_func(df_cas16['explained_sqrt_cosine'].fillna(0)),\n", " ],\n", - " \"spectral_sqrt_cosine_wo_prec\": [\n", - " avg_func(df_val[\"spectral_sqrt_cosine_wo_prec\"]),\n", - " avg_func(df_test[\"spectral_sqrt_cosine_wo_prec\"]),\n", - " avg_func(df_cas16[\"spectral_sqrt_cosine_wo_prec\"].fillna(0)),\n", + " 'spectral_sqrt_cosine_wo_prec': [\n", + " avg_func(df_val['spectral_sqrt_cosine_wo_prec']),\n", + " avg_func(df_test['spectral_sqrt_cosine_wo_prec']),\n", + " avg_func(df_cas16['spectral_sqrt_cosine_wo_prec'].fillna(0)),\n", " ],\n", - " \"explained_sqrt_cosine_wo_prec\": [\n", - " avg_func(df_val[\"explained_sqrt_cosine_wo_prec\"]),\n", - " avg_func(df_test[\"explained_sqrt_cosine_wo_prec\"]),\n", - " avg_func(df_cas16[\"explained_sqrt_cosine_wo_prec\"].fillna(0)),\n", + " 'explained_sqrt_cosine_wo_prec': [\n", + " avg_func(df_val['explained_sqrt_cosine_wo_prec']),\n", + " avg_func(df_test['explained_sqrt_cosine_wo_prec']),\n", + " avg_func(df_cas16['explained_sqrt_cosine_wo_prec'].fillna(0)),\n", " ],\n", " }\n", ")\n", @@ -2114,20 +2109,21 @@ } ], "source": [ - "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", "from fiora.visualization.define_colors import set_light_theme\n", "\n", "set_light_theme()\n", "\n", "# Define custom colors for the precursor types\n", "custom_palette = {\n", - " \"[M+H]+\": \"indianred\", # Nice red tint\n", - " \"[M]+\": \"salmon\", # Nice orange tint\n", - " \"[M-H]-\": \"dodgerblue\", # Nice blue tint\n", - " \"[M]-\": \"lightblue\", # Different blue tint\n", + " '[M+H]+': 'indianred', # Nice red tint\n", + " '[M]+': 'salmon', # Nice orange tint\n", + " '[M-H]-': 'dodgerblue', # Nice blue tint\n", + " '[M]-': 'lightblue', # Different blue tint\n", "}\n", - "precursor_types = [\"[M+H]+\", \"[M]+\", \"[M-H]-\", \"[M]-\"]\n", + "precursor_types = ['[M+H]+', '[M]+', '[M-H]-', '[M]-']\n", "\n", "# Create a figure with two subplots\n", "fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharey=True)\n", @@ -2135,63 +2131,63 @@ "# Left subplot: spectral_sqrt_cosine vs spectral_sqrt_cosine with Precursor_type as hue\n", "sns.boxplot(\n", " data=df_val,\n", - " x=\"Precursor_type\",\n", - " y=\"spectral_sqrt_cosine\",\n", + " x='Precursor_type',\n", + " y='spectral_sqrt_cosine',\n", " ax=axes[0],\n", " palette=custom_palette,\n", - " hue=\"Precursor_type\",\n", + " hue='Precursor_type',\n", " linewidth=2,\n", " legend=False,\n", " order=precursor_types,\n", ")\n", - "axes[0].set_title(\"Spectral Cosine\")\n", - "axes[0].set_ylabel(\"Cosine Similarity\")\n", - "axes[0].set_xlabel(\"Precursor Type\")\n", + "axes[0].set_title('Spectral Cosine')\n", + "axes[0].set_ylabel('Cosine Similarity')\n", + "axes[0].set_xlabel('Precursor Type')\n", "\n", "# Add medians above the median line for the left subplot\n", "for i, precursor_type in enumerate(precursor_types):\n", - " median = df_val[df_val[\"Precursor_type\"] == precursor_type][\n", - " \"spectral_sqrt_cosine\"\n", + " median = df_val[df_val['Precursor_type'] == 
precursor_type][\n",
+    "        'spectral_sqrt_cosine'\n",
     "    ].median()\n",
     "    axes[0].text(\n",
     "        i,\n",
     "        median + 0.01,\n",
-    "        f\"{median:.2f}\",\n",
-    "        ha=\"center\",\n",
-    "        va=\"bottom\",\n",
+    "        f'{median:.2f}',\n",
+    "        ha='center',\n",
+    "        va='bottom',\n",
     "        fontsize=11,\n",
-    "        color=\"black\",\n",
+    "        color='black',\n",
     "    )\n",
     "\n",
     "# Right subplot: spectral_sqrt_cosine_wo_prec vs spectral_sqrt_cosine_wo_prec with Precursor_type as hue\n",
     "sns.boxplot(\n",
     "    data=df_val,\n",
-    "    x=\"Precursor_type\",\n",
-    "    y=\"spectral_sqrt_cosine_wo_prec\",\n",
+    "    x='Precursor_type',\n",
+    "    y='spectral_sqrt_cosine_wo_prec',\n",
     "    ax=axes[1],\n",
-    "    hue=\"Precursor_type\",\n",
+    "    hue='Precursor_type',\n",
     "    palette=custom_palette,\n",
     "    linewidth=2,\n",
     "    legend=True,\n",
     "    order=precursor_types,\n",
     ")\n",
-    "axes[1].set_title(\"Spectral Cosine (wo Prec)\")\n",
-    "axes[1].set_ylabel(\"Cosine Similarity\")\n",
-    "axes[1].set_xlabel(\"Precursor Type\")\n",
+    "axes[1].set_title('Spectral Cosine (wo Prec)')\n",
+    "axes[1].set_ylabel('Cosine Similarity')\n",
+    "axes[1].set_xlabel('Precursor Type')\n",
     "\n",
     "# Add medians above the median line for the right subplot\n",
     "for i, precursor_type in enumerate(precursor_types):\n",
-    "    median = df_val[df_val[\"Precursor_type\"] == precursor_type][\n",
-    "        \"spectral_sqrt_cosine_wo_prec\"\n",
+    "    median = df_val[df_val['Precursor_type'] == precursor_type][\n",
+    "        'spectral_sqrt_cosine_wo_prec'\n",
     "    ].median()\n",
     "    axes[1].text(\n",
     "        i,\n",
     "        median + 0.01,\n",
-    "        f\"{median:.2f}\",\n",
-    "        ha=\"center\",\n",
-    "        va=\"bottom\",\n",
+    "        f'{median:.2f}',\n",
+    "        ha='center',\n",
+    "        va='bottom',\n",
     "        fontsize=11,\n",
-    "        color=\"black\",\n",
+    "        color='black',\n",
     "    )\n",
     "\n",
     "# Adjust layout and show the plot\n",
@@ -2221,63 +2217,63 @@
     "# Left subplot: spectral_sqrt_cosine vs spectral_sqrt_cosine with Precursor_type as hue\n",
     "sns.boxplot(\n",
     "    data=df_test,\n",
-    "    x=\"Precursor_type\",\n",
-    "    y=\"spectral_sqrt_cosine\",\n",
+    "    x='Precursor_type',\n",
+    "    y='spectral_sqrt_cosine',\n",
     "    ax=axes[0],\n",
     "    palette=custom_palette,\n",
-    "    hue=\"Precursor_type\",\n",
+    "    hue='Precursor_type',\n",
     "    linewidth=2,\n",
     "    legend=False,\n",
     "    order=precursor_types,\n",
     ")\n",
-    "axes[0].set_title(\"Spectral Cosine\")\n",
-    "axes[0].set_ylabel(\"Cosine Similarity\")\n",
-    "axes[0].set_xlabel(\"Precursor Type\")\n",
+    "axes[0].set_title('Spectral Cosine')\n",
+    "axes[0].set_ylabel('Cosine Similarity')\n",
+    "axes[0].set_xlabel('Precursor Type')\n",
     "\n",
     "# Add medians above the median line for the left subplot\n",
     "for i, precursor_type in enumerate(precursor_types):\n",
-    "    median = df_test[df_test[\"Precursor_type\"] == precursor_type][\n",
-    "        \"spectral_sqrt_cosine\"\n",
+    "    median = df_test[df_test['Precursor_type'] == precursor_type][\n",
+    "        'spectral_sqrt_cosine'\n",
     "    ].median()\n",
     "    axes[0].text(\n",
     "        i,\n",
     "        median + 0.01,\n",
-    "        f\"{median:.2f}\",\n",
-    "        ha=\"center\",\n",
-    "        va=\"bottom\",\n",
+    "        f'{median:.2f}',\n",
+    "        ha='center',\n",
+    "        va='bottom',\n",
     "        fontsize=11,\n",
-    "        color=\"black\",\n",
+    "        color='black',\n",
     "    )\n",
     "\n",
     "# Right subplot: spectral_sqrt_cosine_wo_prec vs spectral_sqrt_cosine_wo_prec with Precursor_type as hue\n",
     "sns.boxplot(\n",
     "    data=df_test,\n",
-    "    x=\"Precursor_type\",\n",
-    "    y=\"spectral_sqrt_cosine_wo_prec\",\n",
+    "    x='Precursor_type',\n",
+    "    y='spectral_sqrt_cosine_wo_prec',\n",
     "    ax=axes[1],\n",
-    "    hue=\"Precursor_type\",\n",
+    "    hue='Precursor_type',\n",
     "    palette=custom_palette,\n",
     "    linewidth=2,\n",
     "    legend=True,\n",
     "    order=precursor_types,\n",
     ")\n",
-    "axes[1].set_title(\"Spectral Cosine (wo Prec)\")\n",
-    "axes[1].set_ylabel(\"Cosine Similarity\")\n",
-    "axes[1].set_xlabel(\"Precursor Type\")\n",
+    "axes[1].set_title('Spectral Cosine (wo Prec)')\n",
+    "axes[1].set_ylabel('Cosine Similarity')\n",
+    "axes[1].set_xlabel('Precursor Type')\n",
     "\n",
     "# Add medians above the median line for the right subplot\n",
     "for i, precursor_type in enumerate(precursor_types):\n",
-    "    median = df_test[df_test[\"Precursor_type\"] == precursor_type][\n",
-    "        \"spectral_sqrt_cosine_wo_prec\"\n",
+    "    median = df_test[df_test['Precursor_type'] == precursor_type][\n",
+    "        'spectral_sqrt_cosine_wo_prec'\n",
     "    ].median()\n",
     "    axes[1].text(\n",
     "        i,\n",
     "        median + 0.01,\n",
-    "        f\"{median:.2f}\",\n",
-    "        ha=\"center\",\n",
-    "        va=\"bottom\",\n",
+    "        f'{median:.2f}',\n",
+    "        ha='center',\n",
+    "        va='bottom',\n",
     "        fontsize=11,\n",
-    "        color=\"black\",\n",
+    "        color='black',\n",
     "    )\n",
     "\n",
     "# Adjust layout and show the plot\n",
@@ -2307,61 +2303,61 @@
     "# Left subplot: spectral_sqrt_cosine vs spectral_sqrt_cosine with Precursor_type as hue\n",
     "sns.boxplot(\n",
     "    data=df_cas16,\n",
-    "    x=\"Precursor_type\",\n",
-    "    y=\"spectral_sqrt_cosine\",\n",
+    "    x='Precursor_type',\n",
+    "    y='spectral_sqrt_cosine',\n",
     "    ax=axes[0],\n",
     "    palette=custom_palette,\n",
-    "    hue=\"Precursor_type\",\n",
+    "    hue='Precursor_type',\n",
     "    linewidth=2,\n",
     "    legend=False,\n",
     ")\n",
-    "axes[0].set_title(\"Spectral Cosine\")\n",
-    "axes[0].set_ylabel(\"Cosine Similarity\")\n",
-    "axes[0].set_xlabel(\"Precursor Type\")\n",
+    "axes[0].set_title('Spectral Cosine')\n",
+    "axes[0].set_ylabel('Cosine Similarity')\n",
+    "axes[0].set_xlabel('Precursor Type')\n",
     "\n",
     "# Add medians above the median line for the left subplot\n",
-    "for i, precursor_type in enumerate(df_cas16[\"Precursor_type\"].unique()):\n",
-    "    median = df_cas16[df_cas16[\"Precursor_type\"] == precursor_type][\n",
-    "        \"spectral_sqrt_cosine\"\n",
+    "for i, precursor_type in enumerate(df_cas16['Precursor_type'].unique()):\n",
+    "    median = df_cas16[df_cas16['Precursor_type'] == precursor_type][\n",
+    "        'spectral_sqrt_cosine'\n",
     "    ].median()\n",
     "    axes[0].text(\n",
     "        i,\n",
     "        median + 0.01,\n",
-    "        f\"{median:.2f}\",\n",
-    "        ha=\"center\",\n",
-    "        va=\"bottom\",\n",
+    "        f'{median:.2f}',\n",
+    "        ha='center',\n",
+    "        va='bottom',\n",
     "        fontsize=11,\n",
-    "        color=\"black\",\n",
+    "        color='black',\n",
     "    )\n",
     "\n",
     "# Right subplot: spectral_sqrt_cosine_wo_prec vs spectral_sqrt_cosine_wo_prec with Precursor_type as hue\n",
     "sns.boxplot(\n",
     "    data=df_cas16,\n",
-    "    x=\"Precursor_type\",\n",
-    "    y=\"spectral_sqrt_cosine_wo_prec\",\n",
+    "    x='Precursor_type',\n",
+    "    y='spectral_sqrt_cosine_wo_prec',\n",
     "    ax=axes[1],\n",
-    "    hue=\"Precursor_type\",\n",
+    "    hue='Precursor_type',\n",
     "    palette=custom_palette,\n",
     "    linewidth=2,\n",
     "    legend=True,\n",
     ")\n",
-    "axes[1].set_title(\"Spectral Cosine (wo Prec)\")\n",
-    "axes[1].set_ylabel(\"Cosine Similarity\")\n",
-    "axes[1].set_xlabel(\"Precursor Type\")\n",
+    "axes[1].set_title('Spectral Cosine (wo Prec)')\n",
+    "axes[1].set_ylabel('Cosine Similarity')\n",
+    "axes[1].set_xlabel('Precursor Type')\n",
     "\n",
     "# Add medians above the median line for the right subplot\n",
-    "for i, precursor_type in enumerate(df_cas16[\"Precursor_type\"].unique()):\n",
-    "    median = df_cas16[df_cas16[\"Precursor_type\"] == precursor_type][\n",
-    "        \"spectral_sqrt_cosine_wo_prec\"\n",
+    "for i, precursor_type in enumerate(df_cas16['Precursor_type'].unique()):\n",
+    "    median = df_cas16[df_cas16['Precursor_type'] == precursor_type][\n",
+    "        'spectral_sqrt_cosine_wo_prec'\n",
     "    ].median()\n",
     "    axes[1].text(\n",
     "        i,\n",
     "        median + 0.01,\n",
-    "        f\"{median:.2f}\",\n",
-    "        ha=\"center\",\n",
-    "        va=\"bottom\",\n",
+    "        f'{median:.2f}',\n",
+    "        ha='center',\n",
+    "        va='bottom',\n",
     "        fontsize=11,\n",
-    "        color=\"black\",\n",
+    "        color='black',\n",
     "    )\n",
     "\n",
     "# Adjust layout and show the plot\n",
@@ -2387,22 +2383,22 @@
    ],
    "source": [
     "# Extract exact molecular weight from Metabolite and store it in a new column\n",
-    "df_val[\"exact_mol_weight\"] = df_val[\"Metabolite\"].apply(lambda x: x.ExactMolWeight)\n",
+    "df_val['exact_mol_weight'] = df_val['Metabolite'].apply(lambda x: x.ExactMolWeight)\n",
     "\n",
     "# Divide the molecular weight into bins of 100 from 100 to 1000\n",
     "bins = list(range(100, 1100, 100))\n",
-    "df_val[\"mol_weight_bin\"] = pd.cut(df_val[\"exact_mol_weight\"], bins=bins)\n",
+    "df_val['mol_weight_bin'] = pd.cut(df_val['exact_mol_weight'], bins=bins)\n",
     "\n",
     "# Count the number of entries in each bin\n",
-    "bin_counts = df_val[\"mol_weight_bin\"].value_counts(sort=False)\n",
+    "bin_counts = df_val['mol_weight_bin'].value_counts(sort=False)\n",
     "\n",
     "# Create the figure and subplots\n",
     "fig, axes = plt.subplots(\n",
-    "    2, 1, figsize=(12, 8), gridspec_kw={\"height_ratios\": [1, 3]}, sharex=True\n",
+    "    2, 1, figsize=(12, 8), gridspec_kw={'height_ratios': [1, 3]}, sharex=True\n",
     ")\n",
     "\n",
     "# Generate the viridis color palette\n",
-    "viridis_palette = sns.color_palette(\"viridis\", len(bin_counts))\n",
+    "viridis_palette = sns.color_palette('viridis', len(bin_counts))\n",
     "\n",
     "# Top subplot: Bar plot for counts in each bin\n",
     "sns.barplot(\n",
@@ -2414,25 +2410,25 @@
     "    legend=False,\n",
     "    dodge=False,\n",
     ")\n",
-    "axes[0].set_title(\"Counts per Molecular Weight Bin\")\n",
-    "axes[0].set_ylabel(\"Count\")\n",
-    "axes[0].set_xlabel(\"\")\n",
-    "axes[0].tick_params(axis=\"x\", rotation=45)\n",
+    "axes[0].set_title('Counts per Molecular Weight Bin')\n",
+    "axes[0].set_ylabel('Count')\n",
+    "axes[0].set_xlabel('')\n",
+    "axes[0].tick_params(axis='x', rotation=45)\n",
     "\n",
     "# Bottom subplot: Boxplot for spectral sqrt cosine values for each bin\n",
     "sns.boxplot(\n",
     "    data=df_val,\n",
-    "    x=\"mol_weight_bin\",\n",
-    "    y=\"spectral_sqrt_cosine\",\n",
+    "    x='mol_weight_bin',\n",
+    "    y='spectral_sqrt_cosine',\n",
     "    ax=axes[1],\n",
-    "    palette=\"viridis\",\n",
-    "    hue=\"mol_weight_bin\",\n",
+    "    palette='viridis',\n",
+    "    hue='mol_weight_bin',\n",
     "    legend=False,\n",
     ")\n",
-    "axes[1].set_title(\"Spectral Sqrt Cosine vs Molecular Weight\")\n",
-    "axes[1].set_ylabel(\"Spectral Sqrt Cosine\")\n",
-    "axes[1].set_xlabel(\"Molecular Weight Bin\")\n",
-    "axes[1].tick_params(axis=\"x\", rotation=45)\n",
+    "axes[1].set_title('Spectral Sqrt Cosine vs Molecular Weight')\n",
+    "axes[1].set_ylabel('Spectral Sqrt Cosine')\n",
+    "axes[1].set_xlabel('Molecular Weight Bin')\n",
+    "axes[1].tick_params(axis='x', rotation=45)\n",
     "\n",
     "# Adjust layout and show the plot\n",
     "plt.tight_layout()\n",
@@ -2492,7 +2488,7 @@
    }
   ],
   "source": [
-    "raise KeyboardInterrupt(\"Halt! Make sure you wish to save/overwrite model files\")"
+    "raise KeyboardInterrupt('Halt! Make sure you wish to save/overwrite model files')"
    ]
   },
   {
@@ -2501,15 +2497,15 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "model.model_params[\"version\"] = \"FIORA OS v1.0.0\"\n",
-    "model.model_params[\"version_number\"] = \"1.0.0\"\n",
+    "model.model_params['version'] = 'FIORA OS v1.0.0'\n",
+    "model.model_params['version_number'] = '1.0.0'\n",
     "\n",
-    "model.model_params[\"training_library\"] = \"MSnLib v7\"\n",
-    "model.model_params[\"comment\"] = (\n",
-    "    \"This is an open-source FIORA model released on GitHub trained on the MSnLib v7.\"\n",
+    "model.model_params['training_library'] = 'MSnLib v7'\n",
+    "model.model_params['comment'] = (\n",
+    "    'This is an open-source FIORA model released on GitHub, trained on MSnLib v7.'\n",
     ")\n",
-    "model.model_params[\"disclaimer\"] = (\n",
-    "    \"No prediction software is perfect. Use with caution.\"\n",
+    "model.model_params['disclaimer'] = (\n",
+    "    'No prediction software is perfect. Use with caution.'\n",
     ")"
    ]
   },
@@ -2530,13 +2526,13 @@
    "source": [
     "save_model = True\n",
     "if save_model:\n",
-    "    depth = model.model_params[\"depth\"]\n",
-    "    print(f\"Saving model with depth {depth}\")\n",
+    "    depth = model.model_params['depth']\n",
+    "    print(f'Saving model with depth {depth}')\n",
     "    MODEL_PATH = (\n",
-    "        f\"{home}/data/metabolites/pretrained_models/v1.0.0_OS_depth{depth}_Sep25_x4.pt\"\n",
+    "        f'{home}/data/metabolites/pretrained_models/v1.0.0_OS_depth{depth}_Sep25_x4.pt'\n",
     "    )\n",
     "    model.save(MODEL_PATH)\n",
-    "    print(f\"Saved to {MODEL_PATH}\")"
+    "    print(f'Saved to {MODEL_PATH}')"
    ]
   },
@@ -2555,14 +2551,14 @@
    "source": [
     "# Load checkpoints depth_1..13 (mapping to depths 0..12), predict on validation, collect per-depth predictions\n",
     "import os\n",
-    "from fiora.GNN.FioraModel import FioraModel\n",
     "\n",
+    "from fiora.GNN.FioraModel import FioraModel\n",
     "\n",
     "pred_rows = []\n",
     "missing = []\n",
-    "print(\"Benchmarking depth parameter on validation set. This may take a while...\")\n",
+    "print('Benchmarking depth parameter on validation set. This may take a while...')\n",
     "for i in range(1, 14):  # files: checkpoint_depth_1..13.best.pt\n",
-    "    ckpt_path = f\"../../checkpoint_depth_{i}.best.pt\"\n",
+    "    ckpt_path = f'../../checkpoint_depth_{i}.best.pt'\n",
     "    if not os.path.exists(ckpt_path):\n",
     "        missing.append(ckpt_path)\n",
     "        continue\n",
@@ -2573,9 +2569,9 @@
     "    )  # adds spectral_sqrt_cosine\n",
     "    depth = i - 1  # file index -> model depth\n",
     "    tmp = dfi[\n",
-    "        [\"group_id\", \"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\"]\n",
+    "        ['group_id', 'spectral_sqrt_cosine', 'spectral_sqrt_cosine_wo_prec']\n",
     "    ].copy()\n",
-    "    tmp[\"depth\"] = depth\n",
+    "    tmp['depth'] = depth\n",
     "    pred_rows.append(tmp)\n",
     "    del model_i\n",
     "    if torch.cuda.is_available():\n",
@@ -2601,41 +2597,41 @@
    }
   ],
   "source": [
-    "import seaborn as sns\n",
     "import matplotlib.pyplot as plt\n",
+    "import seaborn as sns\n",
     "\n",
     "# Build long-form data from raw preds_long\n",
-    "plot_long = preds_long.assign(depth=preds_long[\"depth\"].astype(int)).melt(\n",
-    "    id_vars=[\"group_id\", \"depth\"],\n",
-    "    value_vars=[\"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\"],\n",
-    "    var_name=\"metric\",\n",
-    "    value_name=\"score\",\n",
+    "plot_long = preds_long.assign(depth=preds_long['depth'].astype(int)).melt(\n",
+    "    id_vars=['group_id', 'depth'],\n",
+    "    value_vars=['spectral_sqrt_cosine', 'spectral_sqrt_cosine_wo_prec'],\n",
+    "    var_name='metric',\n",
+    "    value_name='score',\n",
     ")\n",
-    "plot_long[\"metric\"] = plot_long[\"metric\"].map(\n",
+    "plot_long['metric'] = plot_long['metric'].map(\n",
     "    {\n",
-    "        \"spectral_sqrt_cosine\": \"Cosine Similarity\",\n",
-    "        \"spectral_sqrt_cosine_wo_prec\": \"Cosine Similarity w/o Precursor\",\n",
+    "        'spectral_sqrt_cosine': 'Cosine Similarity',\n",
+    "        'spectral_sqrt_cosine_wo_prec': 'Cosine Similarity w/o Precursor',\n",
     "    }\n",
     ")\n",
     "\n",
     "plt.figure(figsize=(10, 5))\n",
     "sns.pointplot(\n",
     "    data=plot_long,\n",
-    "    x=\"depth\",\n",
-    "    y=\"score\",\n",
-    "    hue=\"metric\",\n",
+    "    x='depth',\n",
+    "    y='score',\n",
+    "    hue='metric',\n",
     "    estimator=np.median,\n",
-    "    errorbar=\"sd\",\n",
+    "    errorbar='sd',\n",
     "    n_boot=1000,\n",
-    "    markers=\"o\",\n",
+    "    markers='o',\n",
     "    dodge=0.2,\n",
     "    capsize=0.2,\n",
     ")\n",
-    "plt.xticks(sorted(plot_long[\"depth\"].unique().tolist()))\n",
-    "plt.xlabel(\"Depth\")\n",
-    "plt.ylabel(\"Median spectral sqrt cosine\")\n",
+    "plt.xticks(sorted(plot_long['depth'].unique().tolist()))\n",
+    "plt.xlabel('Depth')\n",
+    "plt.ylabel('Median spectral sqrt cosine')\n",
     "plt.ylim(0.35, 1.05)\n",
-    "plt.title(\"Validation spectral sqrt cosine over depth (95% CI bars)\")\n",
+    "plt.title('Validation spectral sqrt cosine over depth (SD bars)')\n",
     "plt.legend()\n",
     "plt.tight_layout()\n",
     "plt.show()"
@@ -2688,7 +2684,7 @@
    }
   ],
   "source": [
-    "df_test[\"Metabolite\"].iloc[0].as_geometric_data()"
+    "df_test['Metabolite'].iloc[0].as_geometric_data()"
    ]
   },
   {
@@ -2697,7 +2693,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "dev = \"cuda:0\"\n",
+    "dev = 'cuda:0'\n",
     "mymy = GNNCompiler.load(\n",
     "    MODEL_PATH\n",
     ")  # f\"{home}/data/metabolites/pretrained_models/v0.0.1_merged_2.pt\"\n",
@@ -2754,10 +2750,10 @@
    "source": [
     "import json\n",
     "\n",
-    "with open(MODEL_PATH.replace(\".pt\", \"_params.json\"), \"r\") as fp:\n",
+    "with open(MODEL_PATH.replace('.pt', '_params.json'), 'r') as fp:\n",
     "    p = json.load(fp)\n",
     "hh = GNNCompiler(p)\n",
-    "hh.load_state_dict(torch.load(MODEL_PATH.replace(\".pt\", \"_state.pt\")))\n",
+    "hh.load_state_dict(torch.load(MODEL_PATH.replace('.pt', '_state.pt')))\n",
"hh.load_state_dict(torch.load(MODEL_PATH.replace('.pt', '_state.pt')))\n", "hh.eval()\n", "hh = hh.to(dev)" ] @@ -2800,7 +2796,7 @@ } ], "source": [ - "raise KeyboardInterrupt(\"TODO\")" + "raise KeyboardInterrupt('TODO')" ] }, { @@ -2833,25 +2829,25 @@ "import os\n", "\n", "save_df = False\n", - "cfm_directory = f\"{home}/data/metabolites/cfm-id/\"\n", - "name = \"test_split_negative_solutions_cfm.txt\"\n", - "df_cfm = df_test[[\"group_id\", \"SMILES\", \"Precursor_type\"]]\n", - "df_n = df_cfm[df_cfm[\"Precursor_type\"] == \"[M-H]-\"].drop_duplicates(\n", - " subset=\"group_id\", keep=\"first\"\n", + "cfm_directory = f'{home}/data/metabolites/cfm-id/'\n", + "name = 'test_split_negative_solutions_cfm.txt'\n", + "df_cfm = df_test[['group_id', 'SMILES', 'Precursor_type']]\n", + "df_n = df_cfm[df_cfm['Precursor_type'] == '[M-H]-'].drop_duplicates(\n", + " subset='group_id', keep='first'\n", ")\n", - "df_p = df_cfm[df_cfm[\"Precursor_type\"] == \"[M+H]+\"].drop_duplicates(\n", - " subset=\"group_id\", keep=\"first\"\n", + "df_p = df_cfm[df_cfm['Precursor_type'] == '[M+H]+'].drop_duplicates(\n", + " subset='group_id', keep='first'\n", ")\n", "\n", "print(df_n.head())\n", "\n", "if save_df:\n", " file = os.path.join(cfm_directory, name)\n", - " df_n[[\"group_id\", \"SMILES\"]].to_csv(file, index=False, header=False, sep=\" \")\n", + " df_n[['group_id', 'SMILES']].to_csv(file, index=False, header=False, sep=' ')\n", "\n", - " name = name.replace(\"negative\", \"positive\")\n", + " name = name.replace('negative', 'positive')\n", " file = os.path.join(cfm_directory, name)\n", - " df_p[[\"group_id\", \"SMILES\"]].to_csv(file, index=False, header=False, sep=\" \")" + " df_p[['group_id', 'SMILES']].to_csv(file, index=False, header=False, sep=' ')" ] } ], diff --git a/resources/data/msnlib/download_msnlib.py b/resources/data/msnlib/download_msnlib.py index ac9feba..dadbb39 100644 --- a/resources/data/msnlib/download_msnlib.py +++ b/resources/data/msnlib/download_msnlib.py @@ -26,18 +26,17 @@ from urllib.parse import urlparse from urllib.request import urlopen, urlretrieve - -DEFAULT_URL = "https://zenodo.org/records/16984129" -DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent / "raw" +DEFAULT_URL = 'https://zenodo.org/records/16984129' +DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent / 'raw' def _default_filename(url: str) -> str: name = Path(urlparse(url).path).name - return name or "msnlib_download" + return name or 'msnlib_download' def _parse_zenodo_record_id(url: str) -> str | None: - m = re.search(r"zenodo\.org/(?:records|record)/(\d+)", url) + m = re.search(r'zenodo\.org/(?:records|record)/(\d+)', url) if not m: return None return m.group(1) @@ -51,126 +50,126 @@ def _resolve_zenodo_record( record_max_files: int, ) -> list[tuple[str, str]]: record_id = _parse_zenodo_record_id(url) - if record_id is None or "/files/" in urlparse(url).path: + if record_id is None or '/files/' in urlparse(url).path: resolved_url = url filename = filename_override or _default_filename(resolved_url) return [(resolved_url, filename)] - api_url = f"https://zenodo.org/api/records/{record_id}" + api_url = f'https://zenodo.org/api/records/{record_id}' with urlopen(api_url) as resp: - payload = json.loads(resp.read().decode("utf-8")) + payload = json.loads(resp.read().decode('utf-8')) - files = payload.get("files", []) + files = payload.get('files', []) if not files: - raise RuntimeError(f"No downloadable files found in Zenodo record {record_id}") + raise RuntimeError(f'No downloadable files found in Zenodo 
record {record_id}') selected = [] for file_item in files: - key = str(file_item.get("key", "")) + key = str(file_item.get('key', '')) if not record_pattern or fnmatch.fnmatch(key, record_pattern): selected.append(file_item) if not selected: raise RuntimeError( - f"No files in Zenodo record {record_id} match pattern {record_pattern!r}" + f'No files in Zenodo record {record_id} match pattern {record_pattern!r}' ) - selected = sorted(selected, key=lambda x: str(x.get("key", ""))) + selected = sorted(selected, key=lambda x: str(x.get('key', ''))) if record_max_files > 0: selected = selected[:record_max_files] if filename_override is not None and len(selected) != 1: raise RuntimeError( - "--filename can only be used when exactly one file is selected" + '--filename can only be used when exactly one file is selected' ) resolved = [] for file_item in selected: - links = file_item.get("links", {}) - resolved_url = links.get("self") + links = file_item.get('links', {}) + resolved_url = links.get('self') if not resolved_url: continue - key = str(file_item.get("key") or _default_filename(resolved_url)) + key = str(file_item.get('key') or _default_filename(resolved_url)) filename = filename_override or key resolved.append((resolved_url, filename)) if not resolved: raise RuntimeError( - f"Could not resolve any download URLs for record {record_id}" + f'Could not resolve any download URLs for record {record_id}' ) return resolved def _extract_archive(path: Path, output_dir: Path) -> None: lower = path.name.lower() - if lower.endswith(".zip"): - with zipfile.ZipFile(path, "r") as zf: + if lower.endswith('.zip'): + with zipfile.ZipFile(path, 'r') as zf: zf.extractall(output_dir) return - if lower.endswith(".tar.gz") or lower.endswith(".tgz"): - with tarfile.open(path, "r:gz") as tf: + if lower.endswith('.tar.gz') or lower.endswith('.tgz'): + with tarfile.open(path, 'r:gz') as tf: tf.extractall(output_dir) return - if lower.endswith(".gz") and not lower.endswith(".tar.gz"): - out_path = output_dir / path.with_suffix("").name - with gzip.open(path, "rb") as src, open(out_path, "wb") as dst: + if lower.endswith('.gz') and not lower.endswith('.tar.gz'): + out_path = output_dir / path.with_suffix('').name + with gzip.open(path, 'rb') as src, open(out_path, 'wb') as dst: shutil.copyfileobj(src, dst) return - raise ValueError(f"Unsupported archive format: {path}") + raise ValueError(f'Unsupported archive format: {path}') def _is_archive(path: Path) -> bool: lower = path.name.lower() return ( - lower.endswith(".zip") - or lower.endswith(".tar.gz") - or lower.endswith(".tgz") - or (lower.endswith(".gz") and not lower.endswith(".tar.gz")) + lower.endswith('.zip') + or lower.endswith('.tar.gz') + or lower.endswith('.tgz') + or (lower.endswith('.gz') and not lower.endswith('.tar.gz')) ) def main() -> None: - parser = argparse.ArgumentParser(description="Download MSnLib files.") + parser = argparse.ArgumentParser(description='Download MSnLib files.') parser.add_argument( - "--url", + '--url', default=DEFAULT_URL, help=( - "URL to download. Zenodo record URLs are supported and resolved to a file " - "(default: MSnLib v7 Zenodo record)." + 'URL to download. Zenodo record URLs are supported and resolved to a file ' + '(default: MSnLib v7 Zenodo record).' 
), ) parser.add_argument( - "--output-dir", + '--output-dir', default=str(DEFAULT_OUTPUT_DIR), - help="Directory to store downloads/extracted files.", + help='Directory to store downloads/extracted files.', ) parser.add_argument( - "--filename", + '--filename', default=None, - help="Optional filename override for the downloaded file.", + help='Optional filename override for the downloaded file.', ) parser.add_argument( - "--record-pattern", - default="*_ms2.mgf", + '--record-pattern', + default='*_ms2.mgf', help=( - "Glob pattern for file keys when --url is a Zenodo record " - "(default: *_ms2.mgf)." + 'Glob pattern for file keys when --url is a Zenodo record ' + '(default: *_ms2.mgf).' ), ) parser.add_argument( - "--record-max-files", + '--record-max-files', type=int, default=0, help=( - "Limit number of selected files from a Zenodo record. " - "0 means no limit (default)." + 'Limit number of selected files from a Zenodo record. ' + '0 means no limit (default).' ), ) parser.add_argument( - "--extract", + '--extract', action=argparse.BooleanOptionalAction, default=True, - help="Extract archives after download (default: true).", + help='Extract archives after download (default: true).', ) args = parser.parse_args() @@ -183,11 +182,11 @@ def main() -> None: record_pattern=args.record_pattern, record_max_files=args.record_max_files, ) - print(f"Selected {len(resolved_downloads)} file(s) from {args.url}") + print(f'Selected {len(resolved_downloads)} file(s) from {args.url}') for resolved_url, filename in resolved_downloads: dest = output_dir / filename - print(f"Downloading {resolved_url} -> {dest}") + print(f'Downloading {resolved_url} -> {dest}') urlretrieve(resolved_url, dest) if args.extract: @@ -195,10 +194,10 @@ def main() -> None: extract_dir = output_dir / dest.stem extract_dir.mkdir(parents=True, exist_ok=True) _extract_archive(dest, extract_dir) - print(f"Extracted to {extract_dir}") + print(f'Extracted to {extract_dir}') else: - print(f"No extraction performed for {dest.name} (not an archive).") + print(f'No extraction performed for {dest.name} (not an archive).') -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/resources/data/msnlib/preprocess_msnlib.py b/resources/data/msnlib/preprocess_msnlib.py index 76d95bb..eae15a1 100644 --- a/resources/data/msnlib/preprocess_msnlib.py +++ b/resources/data/msnlib/preprocess_msnlib.py @@ -13,24 +13,24 @@ from rdkit.Chem import Descriptors from sklearn.model_selection import train_test_split +from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder from fiora.IO import mgfReader from fiora.IO.LibraryLoader import LibraryLoader from fiora.MOL import constants as mol_constants from fiora.MOL.Metabolite import Metabolite from fiora.MOL.MetaboliteIndex import MetaboliteIndex -from fiora.GNN.CovariateFeatureEncoder import CovariateFeatureEncoder -RDLogger.DisableLog("rdApp.*") +RDLogger.DisableLog('rdApp.*') BASE_DIR = Path(__file__).resolve().parent -DEFAULT_OUTPUT = BASE_DIR / "library.csv" -DEFAULT_ALLOWED_PRECURSOR_MODES = ["[M+H]+", "[M-H]-", "[M]+", "[M]-"] -DEFAULT_SPECTYPES = ["SINGLE_BEST_SCAN", "SAME_ENERGY", "SINGLE_SCAN"] +DEFAULT_OUTPUT = BASE_DIR / 'library.csv' +DEFAULT_ALLOWED_PRECURSOR_MODES = ['[M+H]+', '[M-H]-', '[M]+', '[M]-'] +DEFAULT_SPECTYPES = ['SINGLE_BEST_SCAN', 'SAME_ENERGY', 'SINGLE_SCAN'] def _log(msg: str, verbose: bool) -> None: if verbose: - print(f"[preprocess_msnlib] {msg}", flush=True) + print(f'[preprocess_msnlib] {msg}', flush=True) def _iter_progress(iterable, *, total: 
 def _load_msnlib_dir(path: Path) -> pd.DataFrame:
     dfs = []
     for filename in sorted(path.iterdir()):
-        if not filename.name.endswith("ms2.mgf"):
+        if not filename.name.endswith('ms2.mgf'):
             continue
         df = pd.DataFrame(mgfReader.read(str(filename)))
-        df["file"] = filename.name
-        df["lib"] = "MSnLib"
-        parts = filename.name.split("_")
-        df["origin"] = parts[1] if len(parts) > 1 else ""
+        df['file'] = filename.name
+        df['lib'] = 'MSnLib'
+        parts = filename.name.split('_')
+        df['origin'] = parts[1] if len(parts) > 1 else ''
         dfs.append(df)
     if not dfs:
-        raise SystemExit(f"No ms2.mgf files found in {path}")
+        raise SystemExit(f'No ms2.mgf files found in {path}')
     df = pd.concat(dfs, ignore_index=True)
     df.reset_index(inplace=True)
     return df
@@ -66,8 +66,8 @@ def _parse(val):
         if pd.isna(val):
             return []
         text = str(val).strip()
-        if "[" in text and "]" in text:
-            text = text.strip("[]")
+        if '[' in text and ']' in text:
+            text = text.strip('[]')
         parts = [p for p in text.split(delim) if p]
         vals = []
         for p in parts:
@@ -87,41 +87,41 @@ def _parse(val):
 def _reweight_groups(df: pd.DataFrame) -> pd.DataFrame:
-    df["num_per_group"] = df["group_id"].map(df["group_id"].value_counts())
-    df["loss_weight"] = 1.0 / df["num_per_group"]
+    df['num_per_group'] = df['group_id'].map(df['group_id'].value_counts())
+    df['loss_weight'] = 1.0 / df['num_per_group']
     return df
 
 
 def _apply_hard_soft_filters(df: pd.DataFrame) -> pd.DataFrame:
-    hard_filters = {"min_peaks": 2, "min_coverage": 0.5, "max_precursor_intensity": 0.9}
+    hard_filters = {'min_peaks': 2, 'min_coverage': 0.5, 'max_precursor_intensity': 0.9}
     soft_filters = {
-        "desired_peaks": 4,
-        "desired_coverage": 0.75,
-        "desired_peak_percentage": 0.5,
+        'desired_peaks': 4,
+        'desired_coverage': 0.75,
+        'desired_peak_percentage': 0.5,
     }
     drop_indices = []
     for i, data in df.iterrows():
-        m = data["Metabolite"]
+        m = data['Metabolite']
         hard_pass = True
-        if m.match_stats["num_peak_matches_filtered"] < hard_filters["min_peaks"]:
+        if m.match_stats['num_peak_matches_filtered'] < hard_filters['min_peaks']:
             hard_pass = False
-        if m.match_stats["coverage"] < hard_filters["min_coverage"]:
+        if m.match_stats['coverage'] < hard_filters['min_coverage']:
             hard_pass = False
-        if m.match_stats["precursor_prob"] > hard_filters["max_precursor_intensity"]:
+        if m.match_stats['precursor_prob'] > hard_filters['max_precursor_intensity']:
             hard_pass = False
         if not hard_pass:
             drop_indices.append(i)
             continue
         soft_pass = False
-        if m.match_stats["num_peak_matches_filtered"] >= soft_filters["desired_peaks"]:
+        if m.match_stats['num_peak_matches_filtered'] >= soft_filters['desired_peaks']:
             soft_pass = True
         if (
-            m.match_stats["percent_peak_matches_filtered"]
-            >= soft_filters["desired_peak_percentage"]
+            m.match_stats['percent_peak_matches_filtered']
+            >= soft_filters['desired_peak_percentage']
         ):
             soft_pass = True
-        if m.match_stats["coverage"] >= soft_filters["desired_coverage"]:
+        if m.match_stats['coverage'] >= soft_filters['desired_coverage']:
             soft_pass = True
         if not soft_pass:
             drop_indices.append(i)
@@ -142,61 +142,61 @@ def _assign_reference_splits(
     L = LibraryLoader()
     df_merged = L.load_from_csv(reference_path)
     other_dfs = {
-        "train": df_merged[df_merged["dataset"] == "training"].drop_duplicates(
-            subset=["group_id"]
+        'train': df_merged[df_merged['dataset'] == 'training'].drop_duplicates(
+            subset=['group_id']
         ),
-        "val": df_merged[df_merged["dataset"] == "validation"].drop_duplicates(
-            subset=["group_id"]
+        'val': df_merged[df_merged['dataset'] == 'validation'].drop_duplicates(
+            subset=['group_id']
         ),
-        "test": df_merged[df_merged["dataset"] == "test"].drop_duplicates(
-            subset=["group_id"]
+        'test': df_merged[df_merged['dataset'] == 'test'].drop_duplicates(
+            subset=['group_id']
         ),
     }
     if casmi16_path:
-        other_dfs["test"] = pd.concat(
-            [other_dfs["test"], pd.read_csv(casmi16_path, index_col=[0])]
+        other_dfs['test'] = pd.concat(
+            [other_dfs['test'], pd.read_csv(casmi16_path, index_col=[0])]
         )
     if casmi16t_path:
-        other_dfs["test"] = pd.concat(
-            [other_dfs["test"], pd.read_csv(casmi16t_path, index_col=[0])]
+        other_dfs['test'] = pd.concat(
+            [other_dfs['test'], pd.read_csv(casmi16t_path, index_col=[0])]
         )
     if casmi22_path:
-        other_dfs["test"] = pd.concat(
-            [other_dfs["test"], pd.read_csv(casmi22_path, index_col=[0])]
+        other_dfs['test'] = pd.concat(
+            [other_dfs['test'], pd.read_csv(casmi22_path, index_col=[0])]
         )
-    other_dfs["test"] = other_dfs["test"].drop_duplicates(subset=["SMILES"])
+    other_dfs['test'] = other_dfs['test'].drop_duplicates(subset=['SMILES'])
 
-    lookup_table = {"train": set(), "val": set(), "test": set()}
+    lookup_table = {'train': set(), 'val': set(), 'test': set()}
     for key, df_x in other_dfs.items():
-        df_x["Metabolite"] = df_x["SMILES"].apply(Metabolite)
+        df_x['Metabolite'] = df_x['SMILES'].apply(Metabolite)
         for _, data in df_x.iterrows():
-            m = data["Metabolite"]
+            m = data['Metabolite']
             lookup_table[key].add((m.ExactMolWeight, m.morganFingerCountOnes))
 
     train, val, test = [], [], []
-    for gid in df["group_id"].unique():
-        m = df[df["group_id"] == gid].iloc[0]["Metabolite"]
+    for gid in df['group_id'].unique():
+        m = df[df['group_id'] == gid].iloc[0]['Metabolite']
         fast_id = (m.ExactMolWeight, m.morganFingerCountOnes)
         found_match = False
-        if fast_id in lookup_table["train"]:
-            for _, data in other_dfs["train"].iterrows():
-                if m == data["Metabolite"]:
+        if fast_id in lookup_table['train']:
+            for _, data in other_dfs['train'].iterrows():
+                if m == data['Metabolite']:
                     train.append(gid)
                     found_match = True
                     break
-        if not found_match and fast_id in lookup_table["val"]:
-            for _, data in other_dfs["val"].iterrows():
-                if m == data["Metabolite"]:
+        if not found_match and fast_id in lookup_table['val']:
+            for _, data in other_dfs['val'].iterrows():
+                if m == data['Metabolite']:
                     val.append(gid)
                     found_match = True
                     break
-        if not found_match and fast_id in lookup_table["test"]:
-            for _, data in other_dfs["test"].iterrows():
-                if m == data["Metabolite"]:
+        if not found_match and fast_id in lookup_table['test']:
+            for _, data in other_dfs['test'].iterrows():
+                if m == data['Metabolite']:
                     test.append(gid)
                     break
 
-    keys = np.unique(df["group_id"].astype(int))
+    keys = np.unique(df['group_id'].astype(int))
     mask = ~np.isin(keys, train + val + test)
     unassigned_keys = keys[mask]
     desired_split_size = int(len(keys) * 0.1)
@@ -223,236 +223,236 @@ def _assign_reference_splits(
     val = np.concatenate((np.array(val), val_keys))
     test = np.concatenate((np.array(test), test_keys))
 
-    df["dataset"] = df["group_id"].apply(
-        lambda x: "training" if x in train else "validation" if x in val else "test"
+    df['dataset'] = df['group_id'].apply(
+        lambda x: 'training' if x in train else 'validation' if x in val else 'test'
     )
-    df["datasplit"] = df["dataset"]
+    df['datasplit'] = df['dataset']
     return df
 
 
 def main() -> None:
     parser = argparse.ArgumentParser(
-        description="Preprocess MSnLib spectra (full parity with msnlib_loader.ipynb)."
+        description='Preprocess MSnLib spectra (full parity with msnlib_loader.ipynb).'
     )
     parser.add_argument(
-        "--msnlib-dir",
-        default=str(BASE_DIR / "raw"),
-        help="Directory with MSnLib ms2.mgf files (default: ./raw).",
+        '--msnlib-dir',
+        default=str(BASE_DIR / 'raw'),
+        help='Directory with MSnLib ms2.mgf files (default: ./raw).',
     )
-    parser.add_argument("--version", default="v7", help="MSnLib version (v5/v7).")
+    parser.add_argument('--version', default='v7', help='MSnLib version (v5/v7).')
     parser.add_argument(
-        "--filter-spectype",
+        '--filter-spectype',
         action=argparse.BooleanOptionalAction,
         default=True,
-        help="Filter spectra by SPECTYPE (default: true).",
+        help='Filter spectra by SPECTYPE (default: true).',
     )
     parser.add_argument(
-        "--allowed-spectypes",
-        default=",".join(DEFAULT_SPECTYPES),
-        help="Comma-separated list of spectypes to keep.",
+        '--allowed-spectypes',
+        default=','.join(DEFAULT_SPECTYPES),
+        help='Comma-separated list of spectypes to keep.',
     )
-    parser.add_argument("--ppm-num", type=int, default=10)
-    parser.add_argument("--ce-upper-limit", type=float, default=100.0)
-    parser.add_argument("--weight-upper-limit", type=float, default=1000.0)
+    parser.add_argument('--ppm-num', type=int, default=10)
+    parser.add_argument('--ce-upper-limit', type=float, default=100.0)
+    parser.add_argument('--weight-upper-limit', type=float, default=1000.0)
     parser.add_argument(
-        "--allowed-precursor-modes",
-        default=",".join(DEFAULT_ALLOWED_PRECURSOR_MODES),
-        help="Comma-separated precursor modes to keep.",
+        '--allowed-precursor-modes',
+        default=','.join(DEFAULT_ALLOWED_PRECURSOR_MODES),
+        help='Comma-separated precursor modes to keep.',
     )
     parser.add_argument(
-        "--reference-splits",
+        '--reference-splits',
         default=None,
-        help="Path to reference datasplits CSV (e.g., datasplits_Jan24.csv).",
+        help='Path to reference datasplits CSV (e.g., datasplits_Jan24.csv).',
     )
-    parser.add_argument("--casmi16", default=None, help="Path to CASMI-16 CSV.")
-    parser.add_argument("--casmi22", default=None, help="Path to CASMI-22 CSV.")
-    parser.add_argument("--casmi16t", default=None, help="Path to CASMI-16T CSV.")
+    parser.add_argument('--casmi16', default=None, help='Path to CASMI-16 CSV.')
+    parser.add_argument('--casmi22', default=None, help='Path to CASMI-22 CSV.')
+    parser.add_argument('--casmi16t', default=None, help='Path to CASMI-16T CSV.')
     parser.add_argument(
-        "--assign-datasplit",
+        '--assign-datasplit',
         action=argparse.BooleanOptionalAction,
         default=True,
-        help="Assign train/val/test splits if no reference provided (default: true).",
+        help='Assign train/val/test splits if no reference provided (default: true).',
     )
-    parser.add_argument("--train-frac", type=float, default=0.8)
-    parser.add_argument("--val-frac", type=float, default=0.1)
-    parser.add_argument("--test-frac", type=float, default=0.1)
-    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument('--train-frac', type=float, default=0.8)
+    parser.add_argument('--val-frac', type=float, default=0.1)
+    parser.add_argument('--test-frac', type=float, default=0.1)
+    parser.add_argument('--seed', type=int, default=42)
     parser.add_argument(
-        "--verbose",
+        '--verbose',
         action=argparse.BooleanOptionalAction,
         default=True,
-        help="Print stage-level progress messages (default: true).",
+        help='Print stage-level progress messages (default: true).',
     )
     parser.add_argument(
-        "--progress",
+        '--progress',
         action=argparse.BooleanOptionalAction,
         default=True,
-        help="Show tqdm progress bars for heavy loops (default: true).",
+        help='Show tqdm progress bars for heavy loops (default: true).',
     )
     parser.add_argument(
-        "--output",
+        '--output',
         default=str(DEFAULT_OUTPUT),
-        help="Output CSV path (default: ./library.csv).",
+        help='Output CSV path (default: ./library.csv).',
     )
 
     args = parser.parse_args()
 
     msnlib_dir = Path(args.msnlib_dir)
-    _log(f"Loading MSnLib files from {msnlib_dir}", args.verbose)
+    _log(f'Loading MSnLib files from {msnlib_dir}', args.verbose)
     df = _load_msnlib_dir(msnlib_dir)
-    _log(f"Loaded {len(df)} raw spectra rows.", args.verbose)
+    _log(f'Loaded {len(df)} raw spectra rows.', args.verbose)
 
-    delim = ", " if args.version == "v5" else ","
-    df["CE_steps"] = _compute_ce_steps(df["COLLISION_ENERGY"], delim)
-    df["Num_steps"] = df["CE_steps"].apply(len)
-    df["CE"] = df["CE_steps"].apply(lambda x: sum(x) / len(x) if x else np.nan)
+    delim = ', ' if args.version == 'v5' else ','
+    df['CE_steps'] = _compute_ce_steps(df['COLLISION_ENERGY'], delim)
+    df['Num_steps'] = df['CE_steps'].apply(len)
+    df['CE'] = df['CE_steps'].apply(lambda x: sum(x) / len(x) if x else np.nan)
 
     if args.filter_spectype:
         allowed_spectypes = [
-            s.strip() for s in args.allowed_spectypes.split(",") if s.strip()
+            s.strip() for s in args.allowed_spectypes.split(',') if s.strip()
         ]
         before = len(df)
-        df = df[df["SPECTYPE"].isin(allowed_spectypes)]
+        df = df[df['SPECTYPE'].isin(allowed_spectypes)]
         _log(
-            f"SPECTYPE filter kept {len(df)}/{before} rows: {allowed_spectypes}",
+            f'SPECTYPE filter kept {len(df)}/{before} rows: {allowed_spectypes}',
             args.verbose,
         )
 
-    df["peaks"] = df["peaks"].apply(lambda p: p if isinstance(p, dict) else None)
+    df['peaks'] = df['peaks'].apply(lambda p: p if isinstance(p, dict) else None)
     before = len(df)
-    df = df[df["peaks"].notna()].copy()
-    _log(f"Rows with valid peaks: {len(df)}/{before}", args.verbose)
+    df = df[df['peaks'].notna()].copy()
+    _log(f'Rows with valid peaks: {len(df)}/{before}', args.verbose)
 
     tolerance = args.ppm_num * mol_constants.PPM
-    df["PPM_num"] = args.ppm_num
-    df["ppm_peak_tolerance"] = tolerance
+    df['PPM_num'] = args.ppm_num
+    df['ppm_peak_tolerance'] = tolerance
 
-    _log("Constructing Metabolite objects...", args.verbose)
-    df["Metabolite"] = [
+    _log('Constructing Metabolite objects...', args.verbose)
+    df['Metabolite'] = [
         Metabolite(smiles)
         for smiles in _iter_progress(
-            df["SMILES"], total=len(df), desc="Metabolites", enabled=args.progress
+            df['SMILES'], total=len(df), desc='Metabolites', enabled=args.progress
         )
     ]
 
-    _log("Building molecular structure graphs...", args.verbose)
+    _log('Building molecular structure graphs...', args.verbose)
     for m in _iter_progress(
-        df["Metabolite"], total=len(df), desc="Build graphs", enabled=args.progress
+        df['Metabolite'], total=len(df), desc='Build graphs', enabled=args.progress
     ):
         m.create_molecular_structure_graph()
 
-    _log("Computing graph attributes...", args.verbose)
+    _log('Computing graph attributes...', args.verbose)
     for m in _iter_progress(
-        df["Metabolite"], total=len(df), desc="Graph attrs", enabled=args.progress
+        df['Metabolite'], total=len(df), desc='Graph attrs', enabled=args.progress
     ):
         m.compute_graph_attributes(memory_safe=False)
 
     mindex = MetaboliteIndex()
-    _log("Indexing metabolites and creating fragmentation trees...", args.verbose)
-    mindex.index_metabolites(df["Metabolite"])
-    h_plus = Chem.MolFromSmiles("[H+]")
+    _log('Indexing metabolites and creating fragmentation trees...', args.verbose)
+    mindex.index_metabolites(df['Metabolite'])
+    h_plus = Chem.MolFromSmiles('[H+]')
     mol_constants.ADDUCT_WEIGHTS.update(
         {
-            "[M+2H]-": Descriptors.ExactMolWt(h_plus)
-            + 1 * Descriptors.ExactMolWt(Chem.MolFromSmiles("[H]")),
-            "[M+3H]-": Descriptors.ExactMolWt(h_plus)
-            + 2 * Descriptors.ExactMolWt(Chem.MolFromSmiles("[H]")),
+            '[M+2H]-': Descriptors.ExactMolWt(h_plus)
+            + 1 * Descriptors.ExactMolWt(Chem.MolFromSmiles('[H]')),
+            '[M+3H]-': Descriptors.ExactMolWt(h_plus)
+            + 2 * Descriptors.ExactMolWt(Chem.MolFromSmiles('[H]')),
         }
     )
     mindex.create_fragmentation_trees()
     mindex.add_fragmentation_trees_to_metabolite_list(
-        df["Metabolite"], graph_mismatch_policy="recompute"
+        df['Metabolite'], graph_mismatch_policy='recompute'
     )
 
-    df["group_id"] = df["Metabolite"].apply(lambda x: x.get_id())
+    df['group_id'] = df['Metabolite'].apply(lambda x: x.get_id())
     df = _reweight_groups(df)
 
-    _log("Matching fragments to peaks...", args.verbose)
+    _log('Matching fragments to peaks...', args.verbose)
     for metabolite, peaks in _iter_progress(
-        zip(df["Metabolite"], df["peaks"]),
+        zip(df['Metabolite'], df['peaks']),
         total=len(df),
-        desc="Match fragments",
+        desc='Match fragments',
         enabled=args.progress,
     ):
         metabolite.match_fragments_to_peaks(
-            peaks["mz"],
-            peaks["intensity"],
+            peaks['mz'],
+            peaks['intensity'],
             tolerance=tolerance,
             match_stats_only=True,
         )
 
-    df["PEPMASS"] = pd.to_numeric(df["PEPMASS"], errors="coerce")
-    df["RTINSECONDS"] = pd.to_numeric(df["RTINSECONDS"], errors="coerce")
-    df["ionization"] = "ESI"
-    df["instrument"] = "HCD"
-    df["Precursor_type"] = df["ADDUCT"]
+    df['PEPMASS'] = pd.to_numeric(df['PEPMASS'], errors='coerce')
+    df['RTINSECONDS'] = pd.to_numeric(df['RTINSECONDS'], errors='coerce')
+    df['ionization'] = 'ESI'
+    df['instrument'] = 'HCD'
+    df['Precursor_type'] = df['ADDUCT']
 
     metadata_key_map = {
-        "name": "NAME",
-        "collision_energy": "CE",
-        "instrument": "instrument",
-        "ionization": "ionization",
-        "precursor_mz": "PEPMASS",
-        "precursor_mode": "Precursor_type",
-        "retention_time": "RTINSECONDS",
-        "ce_steps": "CE_steps",
+        'name': 'NAME',
+        'collision_energy': 'CE',
+        'instrument': 'instrument',
+        'ionization': 'ionization',
+        'precursor_mz': 'PEPMASS',
+        'precursor_mode': 'Precursor_type',
+        'retention_time': 'RTINSECONDS',
+        'ce_steps': 'CE_steps',
     }
     setup_encoder = CovariateFeatureEncoder(
         feature_list=[
-            "collision_energy",
-            "molecular_weight",
-            "precursor_mode",
-            "instrument",
+            'collision_energy',
+            'molecular_weight',
+            'precursor_mode',
+            'instrument',
         ]
     )
     rt_encoder = CovariateFeatureEncoder(
-        feature_list=["molecular_weight", "precursor_mode", "instrument"]
+        feature_list=['molecular_weight', 'precursor_mode', 'instrument']
     )
-    setup_encoder.normalize_features["collision_energy"]["max"] = args.ce_upper_limit
-    setup_encoder.normalize_features["molecular_weight"]["max"] = (
+    setup_encoder.normalize_features['collision_energy']['max'] = args.ce_upper_limit
+    setup_encoder.normalize_features['molecular_weight']['max'] = (
         args.weight_upper_limit
     )
-    rt_encoder.normalize_features["molecular_weight"]["max"] = args.weight_upper_limit
+    rt_encoder.normalize_features['molecular_weight']['max'] = args.weight_upper_limit
 
-    df["summary"] = df.apply(
+    df['summary'] = df.apply(
         lambda x: {key: x[name] for key, name in metadata_key_map.items()}, axis=1
     )
     df.apply(
-        lambda x: x["Metabolite"].add_metadata(x["summary"], setup_encoder, rt_encoder),
+        lambda x: x['Metabolite'].add_metadata(x['summary'], setup_encoder, rt_encoder),
         axis=1,
     )
 
     allowed_precursors = [
-        x.strip() for x in args.allowed_precursor_modes.split(",") if x.strip()
+        x.strip() for x in args.allowed_precursor_modes.split(',') if x.strip()
     ]
     before = len(df)
-    df = df[df["ADDUCT"].isin(allowed_precursors)]
+    df = df[df['ADDUCT'].isin(allowed_precursors)]
     _log(
-        f"Precursor mode filter kept {len(df)}/{before} rows: {allowed_precursors}",
+        f'Precursor mode filter kept {len(df)}/{before} rows: {allowed_precursors}',
         args.verbose,
     )
 
-    correct_energy = df["Metabolite"].apply(
+    correct_energy = df['Metabolite'].apply(
         lambda x: (
-            (x.metadata["collision_energy"] <= args.ce_upper_limit)
-            and (x.metadata["collision_energy"] > 1)
+            (x.metadata['collision_energy'] <= args.ce_upper_limit)
+            and (x.metadata['collision_energy'] > 1)
         )
     )
     before = len(df)
     df = df[correct_energy]
-    _log(f"Collision energy filter kept {len(df)}/{before} rows.", args.verbose)
+    _log(f'Collision energy filter kept {len(df)}/{before} rows.', args.verbose)
 
-    correct_weight = df["Metabolite"].apply(
-        lambda x: x.metadata["molecular_weight"] <= args.weight_upper_limit
+    correct_weight = df['Metabolite'].apply(
+        lambda x: x.metadata['molecular_weight'] <= args.weight_upper_limit
     )
     before = len(df)
     df = df[correct_weight]
-    _log(f"Molecular weight filter kept {len(df)}/{before} rows.", args.verbose)
+    _log(f'Molecular weight filter kept {len(df)}/{before} rows.', args.verbose)
 
     before = len(df)
     df = _apply_hard_soft_filters(df)
-    _log(f"Peak-match quality filters kept {len(df)}/{before} rows.", args.verbose)
+    _log(f'Peak-match quality filters kept {len(df)}/{before} rows.', args.verbose)
 
     if args.reference_splits:
-        _log("Assigning datasplits from reference files...", args.verbose)
+        _log('Assigning datasplits from reference files...', args.verbose)
         df = _assign_reference_splits(
             df,
             args.reference_splits,
@@ -462,8 +462,8 @@ def main() -> None:
             args.seed,
         )
     elif args.assign_datasplit:
-        _log("Assigning random datasplits...", args.verbose)
-        group_ids = df["group_id"].unique().tolist()
+        _log('Assigning random datasplits...', args.verbose)
+        group_ids = df['group_id'].unique().tolist()
         rng = np.random.default_rng(args.seed)
         rng.shuffle(group_ids)
         n = len(group_ids)
@@ -478,34 +478,34 @@ def main() -> None:
 
         def _split_label(gid):
             if gid in train_ids:
-                return "training"
+                return 'training'
             if gid in val_ids:
-                return "validation"
+                return 'validation'
             if gid in test_ids:
-                return "test"
-            return "training"
+                return 'test'
+            return 'training'
 
-        df["datasplit"] = df["group_id"].apply(_split_label)
+        df['datasplit'] = df['group_id'].apply(_split_label)
 
-    if "datasplit" in df.columns:
-        counts = df["datasplit"].value_counts().to_dict()
-        _log(f"Split counts: {counts}", args.verbose)
+    if 'datasplit' in df.columns:
+        counts = df['datasplit'].value_counts().to_dict()
+        _log(f'Split counts: {counts}', args.verbose)
 
     df = _reweight_groups(df)
 
-    if "Metabolite" in df.columns:
-        df = df.drop(columns=["Metabolite"])
+    if 'Metabolite' in df.columns:
+        df = df.drop(columns=['Metabolite'])
 
-    for col in ["peaks", "summary"]:
+    for col in ['peaks', 'summary']:
         if col in df.columns:
             df[col] = df[col].apply(
                 lambda v: json.dumps(v) if isinstance(v, dict) else v
            )
 
-    _log(f"Writing output to {args.output}", args.verbose)
+    _log(f'Writing output to {args.output}', args.verbose)
     df.to_csv(args.output, index=False)
-    print(f"Wrote {len(df)} rows to {args.output}")
+    print(f'Wrote {len(df)} rows to {args.output}')
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     main()

From 8fd8c8cbbdec1995b9353fddabf0729428d798b3 Mon Sep 17 00:00:00 2001
From: Philipp Benner
Date: Sun, 19 Apr 2026 09:39:43 +0200
Subject: [PATCH 12/15] Widen pre-commit Ruff file filters

---
 .pre-commit-config.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c52674e..664754f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,6 +4,6 @@ repos:
     hooks:
       - id: ruff
        args: [--fix]
-        files: ^(fiora|tests)/.*\.py$
+        files: ^((fiora|tests|resources)/.*\.py|scripts/.*)$
       - id: ruff-format
-        files: ^(fiora|tests)/.*\.py$
+        files: ^((fiora|tests|resources)/.*\.py|scripts/.*)$

From 82a8f826e73d989cae709ad5fced48dda82bf2f4 Mon Sep 17 00:00:00 2001
From: ynowatzk
Date: Thu, 23 Apr 2026 16:52:38 +0200
Subject: [PATCH 13/15] Hardcode hydrogen masses (removes warning)

---
 fiora/MOL/constants.py | 55 +++++++++++++-----------------------------
 1 file changed, 17 insertions(+), 38 deletions(-)

diff --git a/fiora/MOL/constants.py b/fiora/MOL/constants.py
index c32fbcf..0a39676 100644
--- a/fiora/MOL/constants.py
+++ b/fiora/MOL/constants.py
@@ -1,25 +1,23 @@
 from rdkit import Chem
-from rdkit.Chem import Descriptors
 
-h_minus = Chem.MolFromSmiles('[H-]')  # hydrid
-h_plus = Chem.MolFromSmiles('[H+]')  # h proton
-h_2 = Chem.MolFromSmiles('[HH]')  # h2
+H_MINUS_MASS = 1.008373611910  # from Descriptors.ExactMolWt(Chem.MolFromSmiles('[H-]'))
+H_PLUS_MASS = 1.007276452090  # from Descriptors.ExactMolWt(Chem.MolFromSmiles('[H+]'))
+H2_MASS = 2.015650064000  # from Descriptors.ExactMolWt(Chem.MolFromSmiles('[HH]'))
+NEUTRAL_H_MASS = 1.007825032000  # from Descriptors.ExactMolWt(Chem.MolFromSmiles('[H]'))
 
 ADDUCT_WEIGHTS = {
-    '[M+H]+': Descriptors.ExactMolWt(h_plus),  # 1.007276,
-    '[M+H]-': Descriptors.ExactMolWt(h_plus),  # TODO might not technically exist
+    '[M+H]+': H_PLUS_MASS,  # 1.007276,
+    '[M+H]-': H_PLUS_MASS,  # TODO might not technically exist
     '[M+NH4]+': 18.033823,
     '[M+Na]+': 22.989218,
-    '[M-H]-': -1 * Descriptors.ExactMolWt(h_plus),
+    '[M-H]-': -1 * H_PLUS_MASS,
     #
     # positive fragment rearrangements #
-    '[M-H]+': -1
-    * Descriptors.ExactMolWt(h_minus),  # Double bond replacing 2 hydrogen atoms + H
+    '[M-H]+': -1 * H_MINUS_MASS,  # Double bond replacing 2 hydrogen atoms + H
     '[M]+': 0,
-    '[M-2H]+': -1 * Descriptors.ExactMolWt(h_2),  # Loosing proton and hydrid
-    '[M-3H]+': -1 * Descriptors.ExactMolWt(h_2)
-    - 1 * Descriptors.ExactMolWt(h_minus),  # 2 Double bonds + H
+    '[M-2H]+': -1 * H2_MASS,  # Losing a proton and a hydride
+    '[M-3H]+': -1 * H2_MASS - 1 * H_MINUS_MASS,  # 2 Double bonds + H
     # experimental cases
     # "[M-4H]+": -1.007276 * 4,
     # "[M-5H]+": -1.007276 * 5,
@@ -27,33 +25,14 @@
     # negative fragment rearrangements #
     # "[M-H]-": -1*Chem.Descriptors.ExactMolWt(h_plus), # see above
-    '[M]-': 0,  # could be one electron too many
-    '[M-2H]-': -1 * Descriptors.ExactMolWt(h_2),
-    '[M-3H]-': -1 * Descriptors.ExactMolWt(h_2)
-    - 1 * Chem.Descriptors.ExactMolWt(h_plus),
-    #
+    '[M]-': 0,  # could be one electron too few
+    '[M-2H]-': -1 * H2_MASS,
+    '[M-3H]-': -1 * H2_MASS - 1 * H_PLUS_MASS,
     # Hydrogen gains
-    #
-    '[M+2H]+': Descriptors.ExactMolWt(h_plus)
-    + 1
-    * Descriptors.ExactMolWt(
-        Chem.MolFromSmiles('[H]')
-    ),  # 1 proton + 1 neutral hydrogens
-    '[M+3H]+': Descriptors.ExactMolWt(h_plus)
-    + 2
-    * Descriptors.ExactMolWt(
-        Chem.MolFromSmiles('[H]')
-    ),  # 1 proton + 2 neutral hydrogens
-    '[M+2H]-': Descriptors.ExactMolWt(h_plus)
-    + 1
-    * Descriptors.ExactMolWt(
-        Chem.MolFromSmiles('[H]')
-    ),  # 1 proton + 2 neutral hydrogens
-    '[M+3H]-': Descriptors.ExactMolWt(h_plus)
-    + 2
-    * Descriptors.ExactMolWt(
-        Chem.MolFromSmiles('[H]')
-    ),  # 1 proton + 2 neutral hydrogens
+    '[M+2H]+': H_PLUS_MASS + 1 * NEUTRAL_H_MASS,  # 1 proton + 1 neutral hydrogen
+    '[M+3H]+': H_PLUS_MASS + 2 * NEUTRAL_H_MASS,  # 1 proton + 2 neutral hydrogens
+    '[M+2H]-': H_PLUS_MASS + 1 * NEUTRAL_H_MASS,  # 1 proton + 1 neutral hydrogen
+    '[M+3H]-': H_PLUS_MASS + 2 * NEUTRAL_H_MASS,  # 1 proton + 2 neutral hydrogens
 }

From 2c203719afd0e1c23d36dd09d0d8920b3de78505 Mon Sep 17 00:00:00 2001
From: ynowatzk
Date: Tue, 28 Apr 2026 09:45:56 +0200
Subject: [PATCH 14/15] Adjust README file

---
 README.md | 44 ++++----------------------------------------
 1 file changed, 4 insertions(+), 40 deletions(-)

diff --git a/README.md b/README.md
index 4d1b579..a1c40b5 100644
--- a/README.md
+++ b/README.md
@@ -82,7 +82,7 @@ By default, an open-source model is selected automatically, and predictions typi
 
 ### Models and Resources
 
-Default model checkpoints are packaged under `fiora/resources/models` (Python package: `fiora.resources.models`). The CLI uses these automatically when `--model default` is selected.
+Default models are packaged under `fiora/resources/models`. The CLI uses these automatically when `--model default` is selected.
 
 Scripts for downloading and preprocessing MSnLib are provided in `resources/data/msnlib` (`download_msnlib.py` and `preprocess_msnlib.py`).
 
@@ -95,14 +95,9 @@ python resources/data/msnlib/preprocess_msnlib.py
 
 Use `--record-pattern` to select a different subset, e.g. `--record-pattern "*_pos_*.mgf"`.
 
-### MSnLib Training Parity (Notebook vs CLI)
+### MSnLib (Re)training
 
-The training notebooks override categorical feature sets for MSnLib:
-
-- `instrument`: `["HCD"]`
-- `precursor_mode`: `["[M+H]+", "[M-H]-", "[M]+", "[M]-"]`
-
-To match notebook training results when using `fiora-train`, pass the same overrides:
+To (re)train a new FIORA model, use `fiora-train`. For example, to train on MSnLib with the same parameters used for the v1.0 release:
 
 ```bash
 fiora-train \
@@ -119,40 +114,9 @@ To persist per-epoch training history, add `--history-out` (supports `.json` or
 fiora-train ... --history-out checkpoints/fiora_history.json
 ```
 
-`pin_memory` is enabled automatically on CUDA; you can override with `--pin-memory` or `--no-pin-memory`.
-`--num-workers` is used for both DataLoader workers and parallel preprocessing (thread-based metabolite graph/peak matching setup) in the training CLI.
-
-For stronger cosine performance, a common setup is:
-
-```bash
-# Stage 1
-fiora-train \
-  -i resources/data/msnlib/library.csv \
-  -o checkpoints/fiora_stage1.pt \
-  --device cuda:0 \
-  --instruments HCD \
-  --precursor-modes "[M+H]+,[M-H]-,[M]+,[M]-" \
-  --hidden-dimension 384 \
-  --residual-connections \
-  --no-layer-stacking
-
-# Stage 2 (optional continuation)
-fiora-train \
-  -i resources/data/msnlib/library.csv \
-  -o checkpoints/fiora.pt \
-  --resume checkpoints/fiora_stage1.pt \
-  --device cuda:0 \
-  --instruments HCD \
-  --precursor-modes "[M+H]+,[M-H]-,[M]+,[M]-" \
-  --loss weighted_mse \
-  --y-label compiled_probsSQRT \
-  --learning-rate 5e-5 \
-  --epochs 30
-```
-
 ### Model Evaluation CLI
 
-You can evaluate a trained checkpoint on validation/test splits with:
+You can evaluate a trained model on validation/test splits with:
 
 ```bash
 fiora-eval \

From 5d3e844f7a69d6488f41996d23ca314ff4b0ff4 Mon Sep 17 00:00:00 2001
From: ynowatzk
Date: Tue, 28 Apr 2026 10:42:19 +0200
Subject: [PATCH 15/15] Format constants for Ruff

---
 fiora/MOL/constants.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/fiora/MOL/constants.py b/fiora/MOL/constants.py
index 0a39676..62cf186 100644
--- a/fiora/MOL/constants.py
+++ b/fiora/MOL/constants.py
@@ -3,7 +3,9 @@
 H_MINUS_MASS = 1.008373611910  # from Descriptors.ExactMolWt(Chem.MolFromSmiles('[H-]'))
 H_PLUS_MASS = 1.007276452090  # from Descriptors.ExactMolWt(Chem.MolFromSmiles('[H+]'))
 H2_MASS = 2.015650064000  # from Descriptors.ExactMolWt(Chem.MolFromSmiles('[HH]'))
-NEUTRAL_H_MASS = 1.007825032000  # from Descriptors.ExactMolWt(Chem.MolFromSmiles('[H]'))
+NEUTRAL_H_MASS = (
+    1.007825032000  # from Descriptors.ExactMolWt(Chem.MolFromSmiles('[H]'))
+)
 
 ADDUCT_WEIGHTS = {
     '[M+H]+': H_PLUS_MASS,  # 1.007276,
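
For reference, the masses hardcoded in PATCH 13 (and reformatted in PATCH 15) can be re-derived from the RDKit calls quoted in the constants' comments. The sketch below is illustrative and not part of the patch series; it assumes RDKit is installed, and the glucose SMILES is an arbitrary example molecule chosen here, not one used by the repository:

```python
# Re-derive the hardcoded hydrogen masses from the RDKit expressions quoted
# in the constants' comments, then apply an ADDUCT_WEIGHTS-style offset.
from rdkit import Chem
from rdkit.Chem import Descriptors

expected = {
    '[H-]': 1.008373611910,  # H_MINUS_MASS
    '[H+]': 1.007276452090,  # H_PLUS_MASS
    '[HH]': 2.015650064000,  # H2_MASS
    '[H]': 1.007825032000,   # NEUTRAL_H_MASS
}
for smiles, ref in expected.items():
    mass = Descriptors.ExactMolWt(Chem.MolFromSmiles(smiles))
    assert abs(mass - ref) < 1e-6, (smiles, mass, ref)
    print(f'{smiles}: {mass:.12f}')

# Theoretical [M-H]- precursor m/z for glucose (hypothetical example):
# neutral exact mass plus the '[M-H]-' offset, i.e. -1 * H_PLUS_MASS.
glucose = Chem.MolFromSmiles('OCC1OC(O)C(O)C(O)C1O')
mz = Descriptors.ExactMolWt(glucose) - expected['[H+]']
print(f'[M-H]- m/z: {mz:.4f}')  # ~179.0561
```

Hardcoding these values keeps `ADDUCT_WEIGHTS` numerically identical to the previous RDKit-computed version while avoiding the import-time computation (and the warning) referenced in the PATCH 13 subject line.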