From cbef651f2d2c54df127fe4f937aa4eedd9343fce Mon Sep 17 00:00:00 2001 From: taillades Date: Sun, 19 Oct 2025 13:43:40 -0700 Subject: [PATCH 1/4] Agent formatting --- MANIFEST.in | 3 +- README.md | 132 ------------------------- efold/__init__.py | 5 - efold/api/__init__.py | 1 - efold/api/run.py | 70 +++++++------ efold/cli.py | 49 ++++++---- efold/config.py | 55 ----------- efold/core/__init__.py | 4 - efold/core/batch.py | 90 ++++++++--------- efold/core/callbacks.py | 25 +---- efold/core/dataloader.py | 10 +- efold/core/datamodule.py | 54 +++++----- efold/core/dataset.py | 69 ++++++------- efold/core/datatype.py | 20 ++-- efold/core/embeddings.py | 26 ++--- efold/core/loader.py | 16 ++- efold/core/logger.py | 10 +- efold/core/metrics.py | 25 ++--- efold/core/model.py | 79 +++++++-------- efold/core/path.py | 30 ++---- efold/core/postprocess.py | 97 +++++++++--------- efold/core/sampler.py | 52 +++++----- efold/core/util.py | 18 ++-- efold/core/visualisation.py | 22 ++--- efold/models/__init__.py | 1 - efold/models/cnn.py | 34 ++----- efold/models/efold.py | 130 ++++++++++-------------- efold/models/factory.py | 16 ++- efold/models/ribonanza.py | 23 ++--- efold/models/transformer.py | 41 +++----- efold/models/unet.py | 168 ++++++++++++++++---------------- efold/settings.py | 56 +++++++++++ efold/settings.yaml | 66 +++++++++++++ efold/util/__init__.py | 21 ---- efold/util/format_conversion.py | 50 +++++----- pyproject.toml | 31 +++--- scripts/average_weigths.py | 12 +-- scripts/cnn_template.py | 40 +++++--- scripts/efold_training.py | 24 ++--- scripts/mlp-template.py | 19 ++-- scripts/ribonanza-template.py | 23 ++--- scripts/transformer-template.py | 22 ++--- scripts/unet_training.py | 26 ++--- setup.py | 28 +++--- tests/test_eFold.ipynb | 159 +++++++++++++++++++----------- tests/test_speed_eFold.py | 43 ++++---- 46 files changed, 933 insertions(+), 1062 deletions(-) delete mode 100644 efold/config.py create mode 100644 efold/settings.py create mode 100644 efold/settings.yaml diff --git a/MANIFEST.in b/MANIFEST.in index d54bfb5..0674363 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,2 @@ -requirements.txt \ No newline at end of file +requirements.txt +include efold/settings.yaml \ No newline at end of file diff --git a/README.md b/README.md index d58641b..e69de29 100644 --- a/README.md +++ b/README.md @@ -1,132 +0,0 @@ -# eFold - -This repo contains the pytorch code for our paper “*Diverse Database and Machine Learning Model to narrow the generalization gap in RNA structure prediction”* - -[[BioRXiv](https://www.biorxiv.org/content/10.1101/2024.01.24.577093v1.full)] [[Data](https://huggingface.co/rouskinlab)] - -## Install - -```bash -pip install efold -``` - - -## Inference mode - -### Using the command line - -From a sequence: - -```bash -efold AAACAUGAGGAUUACCCAUGU -o seq.txt -cat seq.txt - -AAACAUGAGGAUUACCCAUGU -..(((((.((....))))))) -``` - -or a fasta file: - -```bash -efold --fasta example.fasta -``` - -Using different formats: -```bash -efold AAACAUGAGGAUUACCCAUGU -bp # base pairs -efold AAACAUGAGGAUUACCCAUGU -db # dotbracket (default) -``` - -Output can be .json, .csv or .txt -```bash -efold AAACAUGAGGAUUACCCAUGU -o output.csv -``` - -Run help: -```bash -efold -h -``` - -### Using python - -```python ->>> from efold import inference ->>> inference('AAACAUGAGGAUUACCCAUGU', fmt='dotbracket') -..(((((.((....))))))) -``` - -## Inference speed -Tested on a AMD EPYC 7272 12 core processor, with 32GB RAM and a RTX3090 GPU - -![alt text](tests/speed_comparison.jpg) - -## File 
structure - -```bash -efold/ - api/ # for inference calls - core/ # backend - models/ # where we define eFold and other models - resources/ - efold_weights.py # our best model weights -scripts/ - efold_training.py # our training script - [...] -LICENSE -requirements.txt -pyproject.toml -``` - -## Data - -### List of the datasets we used - -A breakdown of the data we used is summarized [here](https://github.com/rouskinlab/efold_data). All the data is stored on the [HuggingFace](https://huggingface.co/rouskinlab). - -### Get the data - -You can download our datasets using [rouskinHF](https://github.com/rouskinlab/rouskinhf): - -```bash -pip install rouskinhf -``` - -And in your code, write: - -```python ->>> import rouskinhf ->>> data = rouskinhf.get_dataset('ribo500-blast') # look at the dataset names on huggingface -``` - - - -## Reproducing our results - -### Training - -A [training script](scripts/efold_training.py) is provided to train eFold from scratch. - -### Testing - -A [notebook](tests/test_eFold.ipynb) is provided to run eFold inference on the four test sets, compute the F1 score and check the validity of the structures. - - -## Citation - -**Plain text:** - -Albéric A. de Lajarte, Yves J. Martin des Taillades, Colin Kalicki, Federico Fuchs Wightman, Justin Aruda, Dragui Salazar, Matthew F. Allan, Casper L’Esperance-Kerckhoff, Alex Kashi, Fabrice Jossinet, Silvi Rouskin. “Diverse Database and Machine Learning Model to narrow the generalization gap in RNA structure prediction”. bioRxiv 2024.01.24.577093; doi: https://doi.org/10.1101/2024.01.24.577093. 2024 - -**BibTex:** - -``` -@article {Lajarte_Martin_2024, - title = {Diverse Database and Machine Learning Model to narrow the generalization gap in RNA structure prediction}, - author = {Alb{\'e}ric A. de Lajarte and Yves J. Martin des Taillades and Colin Kalicki and Federico Fuchs Wightman and Justin Aruda and Dragui Salazar and Matthew F. 
Allan and Casper L{\textquoteright}Esperance-Kerckhoff and Alex Kashi and Fabrice Jossinet and Silvi Rouskin}, - year = {2024}, - doi = {10.1101/2024.01.24.577093}, - URL = {https://www.biorxiv.org/content/early/2024/01/25/2024.01.24.577093}, - journal = {bioRxiv} -} - -``` diff --git a/efold/__init__.py b/efold/__init__.py index d99bd05..e69de29 100644 --- a/efold/__init__.py +++ b/efold/__init__.py @@ -1,5 +0,0 @@ -from .models import create_model -from .core import * -from .util import * -from .config import * -from .api import * \ No newline at end of file diff --git a/efold/api/__init__.py b/efold/api/__init__.py index 35095cd..e69de29 100644 --- a/efold/api/__init__.py +++ b/efold/api/__init__.py @@ -1 +0,0 @@ -from .run import run as inference \ No newline at end of file diff --git a/efold/api/run.py b/efold/api/run.py index f21cddf..864baf1 100644 --- a/efold/api/run.py +++ b/efold/api/run.py @@ -1,19 +1,19 @@ +import numpy as np import os -from typing import List, Union -from ..models import create_model import torch from os.path import join, dirname -from ..core import batch -from ..core.embeddings import sequence_to_int -from ..core.postprocess import Postprocess -import numpy as np -from ..util.format_conversion import convert_bp_list_to_dotbracket +from typing import List, Union + +from efold.core import batch, embeddings, postprocess +from efold.models import factory +from efold.util import format_conversion torch.set_default_dtype(torch.float32) -postprocesser = Postprocess() +postprocesser = postprocess.Postprocess() + -def _load_sequences_from_fasta(fasta:str): +def _load_sequences_from_fasta(fasta: str) -> list[str]: with open(fasta, "r") as f: lines = f.readlines() sequences = [] @@ -24,36 +24,44 @@ def _load_sequences_from_fasta(fasta:str): sequences[-1] += line.strip() return sequences -def _predict_structure(model, sequence:str, device='cpu'): - seq = sequence_to_int(sequence).unsqueeze(0) +def _predict_structure(model, sequence: str, device: str = "cpu") -> list[tuple[int, int]]: + seq = embeddings.sequence_to_int(sequence).unsqueeze(0) b = batch.Batch( sequence=seq, reference=[""], length=[len(seq)], - L = len(seq), + L=len(seq), use_error=False, batch_size=1, data_types=["sequence"], - dt_count={"sequence": 1}).to(device) - + dt_count={"sequence": 1}, + ).to(device) + # predict the structure with torch.inference_mode(): pred = model(b) - structure = postprocesser.run(pred['structure'].to('cpu'), b.get('sequence').to('cpu')).numpy().round()[0] + structure = ( + postprocesser.run(pred["structure"].to("cpu"), b.get("sequence").to("cpu")) + .numpy() + .round()[0] + ) # turn into 1-indexed base pairs - return [(b,c) for b, c in (np.stack(np.where(np.triu(structure) == 1)) + 1).T] + return [(b, c) for b, c in (np.stack(np.where(np.triu(structure) == 1)) + 1).T] + -def run(arg:Union[str, List[str]]=None, fmt="dotbracket", device=None): +def run( + arg: Union[str, List[str]] = None, fmt: str = "dotbracket", device: str = None +) -> dict[str, Union[str, list[tuple[int, int]]]]: """Runs the Efold API on the provided sequence or fasta file. - + Args: - arg (str): The sequence or the list of sequences to run Efold on, or the path to a fasta file containing the sequences. - + arg (str): The sequence or the list of sequences to run Efold on, or the path to a fasta file containing the sequences. + Returns: dict: A dictionary containing the sequences as keys and the predicted secondary structures as values. 
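For review context, a minimal usage sketch of the `run()` entry point this hunk documents. It assumes the packaged weights under `efold/resources/efold_weights.pt` are present (as loaded later in this file); the printed structures are illustrative, not captured output.

```python
# Sketch: calling the refactored inference API (efold.api.run.run).
from efold.api.run import run

# One sequence, dot-bracket output (the default format).
print(run("GGGAAAUCC", fmt="dotbracket"))

# A list of sequences, returned as 1-indexed base-pair tuples instead.
results = run(["GGGAAAUCC", "AAACAUGAGGAUUACCCAUGU"], fmt="bp")
for sequence, pairs in results.items():
    print(sequence, pairs)
```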
- + Examples: >>> from efold.api.run import run >>> structure = run("GGGAAAUCC") # this is awful, we need to remove the prints @@ -66,9 +74,11 @@ def run(arg:Union[str, List[str]]=None, fmt="dotbracket", device=None): No scaling, use preLN Replace GLU with swish for Conv >>> assert structure == {'GGGAAAUCC': [(1, 9), (2, 8)]}, "Test failed: {}".format(structure) - + """ - assert fmt in ["dotbracket", "basepair", 'bp'], "Invalid format. Must be either 'dotbracket' or 'basepair'" + assert fmt in ["dotbracket", "basepair", "bp"], ( + "Invalid format. Must be either 'dotbracket' or 'basepair'" + ) # Check if the input is valid if not arg: @@ -83,7 +93,7 @@ def run(arg:Union[str, List[str]]=None, fmt="dotbracket", device=None): sequences = arg else: raise ValueError("Either sequence or fasta must be provided") - + # Get device if not device: if torch.cuda.is_available(): @@ -92,7 +102,7 @@ def run(arg:Union[str, List[str]]=None, fmt="dotbracket", device=None): device = torch.device("cpu") # Load best model - model = create_model( + model = factory.create_model( model="efold", ntoken=5, d_model=64, @@ -105,17 +115,19 @@ def run(arg:Union[str, List[str]]=None, fmt="dotbracket", device=None): weight_decay=0, gamma=0.995, ) - model.load_state_dict(torch.load(join(dirname(dirname(__file__)), "resources/efold_weights.pt")), strict=False) + model.load_state_dict( + torch.load(join(dirname(dirname(__file__)), "resources/efold_weights.pt")), strict=False + ) model.eval() model = model.to(device) structures = [] - for seq in sequences: + for seq in sequences: structure = _predict_structure(model, seq, device=device) if fmt == "dotbracket": - db_structure = convert_bp_list_to_dotbracket(structure, len(seq)) + db_structure = format_conversion.convert_bp_list_to_dotbracket(structure, len(seq)) if db_structure != None: structure = db_structure structures.append(structure) - return {seq: structure for seq, structure in zip(sequences, structures)} \ No newline at end of file + return {seq: structure for seq, structure in zip(sequences, structures)} diff --git a/efold/cli.py b/efold/cli.py index afb737b..2f4970e 100644 --- a/efold/cli.py +++ b/efold/cli.py @@ -1,34 +1,42 @@ import json import click -from efold.api.run import run - -@click.command('efold') -@click.argument('sequence', required=False, type=str) -@click.option('--fasta', '-f', help='Input FASTA file path') -@click.option('--output', '-o', default='output.txt', help='Output file path (json, txt or csv)', type=click.Path()) -@click.option('--basepair/--dotbracket', '-bp/-db', default=False, help='Output structure format') -@click.option('--help', '-h', is_flag=True, help='Show this message', type=bool) -def cli(sequence, fasta, output, basepair, help): - + +from efold.api import run + + +@click.command("efold") +@click.argument("sequence", required=False, type=str) +@click.option("--fasta", "-f", help="Input FASTA file path") +@click.option( + "--output", + "-o", + default="output.txt", + help="Output file path (json, txt or csv)", + type=click.Path(), +) +@click.option("--basepair/--dotbracket", "-bp/-db", default=False, help="Output structure format") +@click.option("--help", "-h", is_flag=True, help="Show this message", type=bool) +def cli(sequence: str, fasta: str, output: str, basepair: bool, help: bool) -> None: if help: click.echo(cli.get_help(click.Context(cli))) return - - fmt = 'bp' if basepair else 'dotbracket' + + fmt = "bp" if basepair else "dotbracket" if sequence: - result = run(sequence, fmt) + result = run.run(sequence, fmt) elif 
fasta: - result = run(fasta, fmt) + result = run.run(fasta, fmt) else: click.echo("Please provide either a sequence or a FASTA file.") return - with open(output, 'w') as f: - file_fmt = output.split('.')[-1] - if file_fmt == 'json': + with open(output, "w") as f: + file_fmt = output.split(".")[-1] + if file_fmt == "json": f.write(json.dumps(result, indent=4)) - elif file_fmt == 'csv': + elif file_fmt == "csv": import csv + writer = csv.writer(f) writer.writerows(result.items()) else: @@ -40,5 +48,6 @@ def cli(sequence, fasta, output, basepair, help): click.echo() click.echo(f"Output saved to {output}") -if __name__ == '__main__': - cli() \ No newline at end of file + +if __name__ == "__main__": + cli() diff --git a/efold/config.py b/efold/config.py deleted file mode 100644 index 3387524..0000000 --- a/efold/config.py +++ /dev/null @@ -1,55 +0,0 @@ -from torch import float32, cuda, backends -import torch - -seq2int = {"X": 0, "A": 1, "C": 2, "G": 3, "U": 4} # , 'S': 5, 'E': 6} -# seq2int = {"X": 0, "A": 1, "U": 2, "C": 3, "G": 4} -int2seq = {v: k for k, v in seq2int.items()} - -START_TOKEN = None # seq2int['S'] -END_TOKEN = None # seq2int['E'] -PADDING_TOKEN = seq2int["X"] - -DEFAULT_FORMAT = float32 -torch.set_default_dtype(DEFAULT_FORMAT) -UKN = -1000.0 -VAL_GU = 0.095 -device = ( - "cuda" - if cuda.is_available() - # else "mps" # moi j'aime bien le mps - # if backends.mps.is_available() - else "cpu" -) - -TEST_SETS = { - "structure": ["PDB", "archiveII", "lncRNA_nonFiltered", "viral_fragments"], - "sequence": [], - "dms": [], - "shape": [], -} - - -TEST_SETS_NAMES = [i for j in TEST_SETS.values() for i in j] -DATA_TYPES_TEST_SETS = [k for k, v in TEST_SETS.items() for i in v] - -DATA_TYPES = ["structure", "dms", "shape"] -DATA_TYPES_FORMAT = { - "structure": torch.int32, - "dms": DEFAULT_FORMAT, - "shape": DEFAULT_FORMAT, -} -REFERENCE_METRIC = {"structure": "f1", "dms": "mae", "shape": "mae"} -REF_METRIC_SIGN = {"structure": 1, "dms": -1, "shape": -1} -POSSIBLE_METRICS = { - "structure": ["f1"], # , "mFMI"], - "dms": ["mae", "r2", "pearson"], - "shape": ["mae", "r2", "pearson"], -} - -DTYPE_PER_DATA_TYPE = { - "structure": torch.int32, - "dms": DEFAULT_FORMAT, - "shape": DEFAULT_FORMAT, -} - -torch.set_default_dtype(torch.float32) diff --git a/efold/core/__init__.py b/efold/core/__init__.py index 0f07848..e69de29 100644 --- a/efold/core/__init__.py +++ b/efold/core/__init__.py @@ -1,4 +0,0 @@ -from .datamodule import DataModule -from . import metrics -from .dataset import Dataset -from .postprocess import Postprocess diff --git a/efold/core/batch.py b/efold/core/batch.py index 6ae0614..2367366 100644 --- a/efold/core/batch.py +++ b/efold/core/batch.py @@ -1,33 +1,29 @@ import torch -from torch import tensor import torch.nn.functional as F -from .embeddings import base_pairs_to_pairing_matrix, sequence_to_int -from ..config import device, POSSIBLE_METRICS, UKN -from typing import Dict -from .datatype import data_type_factory -from .util import split_data_type -from torch import cuda, backends +from efold import settings +from efold.core import datatype, embeddings, util -def _pad(arr, L, data_type, accept_none=False): + +def _pad(arr: torch.Tensor, L: int, data_type: str, accept_none: bool = False) -> torch.Tensor: padding_values = { "sequence": 0, - "dms": UKN, - "shape": UKN, + "dms": settings.UKN, + "shape": settings.UKN, } - assert ( - data_type in padding_values.keys() - ), f"Unknown data type {data_type}. If you want to pad a structure, use base_pairs_to_pairing_matrix." 
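As a reviewer aid, a small sketch of the padding convention this assertion guards: sequences pad with 0, continuous signals pad with `settings.UKN`, assumed here to keep the old `config.py` value of `-1000.0`.

```python
# Sketch of _pad's fill behaviour for a continuous data type (values illustrative).
import torch
import torch.nn.functional as F

UKN = -1000.0  # assumed value of settings.UKN (it was -1000.0 in the deleted config.py)
dms = torch.tensor([0.12, 0.85, 0.03])

# Right-pad a length-3 DMS profile to L=5; the UKN entries are excluded from
# the loss later via `mask[true != settings.UKN] = 1` in core/model.py.
padded = F.pad(dms, (0, 5 - len(dms)), value=UKN)
print(padded)  # the last two entries hold the UKN fill value
```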
+ assert data_type in padding_values.keys(), ( + f"Unknown data type {data_type}. If you want to pad a structure, use base_pairs_to_pairing_matrix." + ) if accept_none and arr is None: - return tensor([padding_values[data_type]] * L) + return torch.tensor([padding_values[data_type]] * L) return F.pad(arr, (0, L - len(arr)), value=padding_values[data_type]) -def get_padded_vector(dp, data_type, data_part, L): +def get_padded_vector(dp: dict, data_type: str, data_part: str, L: int) -> torch.Tensor: if getattr(dp, data_type) is None: - return tensor([UKN] * L) + return torch.tensor([settings.UKN] * L) if getattr(getattr(dp, data_type), data_part) is None: - return tensor([UKN] * L) + return torch.tensor([settings.UKN] * L) return _pad(getattr(getattr(dp, data_type), data_part), L, data_type) @@ -45,7 +41,7 @@ def __init__( dms=None, shape=None, structure=None, - device = 'cpu' + device="cpu", ): self.reference = reference self.sequence = sequence @@ -66,7 +62,7 @@ def from_dataset_items( batch_data: list, data_type: str, use_error: bool, - structure_padding_value: float = UKN, + structure_padding_value: float = settings.UKN, ): reference = [dp["reference"] for dp in batch_data] length = [dp["length"] for dp in batch_data] @@ -74,7 +70,7 @@ def from_dataset_items( # move the conversion to the dataset sequence = torch.stack( - [_pad(sequence_to_int(dp["sequence"]), L, "sequence") for dp in batch_data] + [_pad(embeddings.sequence_to_int(dp["sequence"]), L, "sequence") for dp in batch_data] ) batch_size = len(reference) @@ -91,10 +87,10 @@ def from_dataset_items( } for dt in data_type: if dt == "structure": - data[dt] = data_type_factory["batch"][dt]( + data[dt] = datatype.data_type_factory["batch"][dt]( true=torch.stack( [ - base_pairs_to_pairing_matrix( + embeddings.base_pairs_to_pairing_matrix( dp["structure"]["true"], l, padding=L, @@ -113,16 +109,14 @@ def from_dataset_items( true = torch.stack(true) # use error if there's a single non-None error and if the true signal is not None - if use_error and len( - [1 for dp in batch_data if dp[dt]["error"] is not None] - ): + if use_error and len([1 for dp in batch_data if dp[dt]["error"] is not None]): for dp in batch_data: error.append(_pad(dp[dt]["error"], L, dt, accept_none=True)) error = torch.stack(error) else: error = [None] * batch_size - data[dt] = data_type_factory["batch"][dt](true=true, error=error) + data[dt] = datatype.data_type_factory["batch"][dt](true=true, error=error) return cls( reference=reference, @@ -136,12 +130,12 @@ def from_dataset_items( **data, ) - def get(self, data_type, index=None, to_numpy=False): + def get(self, data_type: str, index: int = None, to_numpy: bool = False): if data_type in ["reference", "sequence", "length"]: out = getattr(self, data_type) data_part = None else: - data_part, data_type = split_data_type(data_type) + data_part, data_type = util.split_data_type(data_type) # could be in the dataset but wasn't requested in the dm init if data_type not in self.data_types: @@ -168,7 +162,7 @@ def get(self, data_type, index=None, to_numpy=False): out = out.squeeze().cpu().numpy() return out - def integrate_prediction(self, prediction): + def integrate_prediction(self, prediction: dict) -> None: for data_type, pred in prediction.items(): if getattr(self, data_type) is not None: getattr(self, data_type).pred = pred @@ -176,44 +170,39 @@ def integrate_prediction(self, prediction): setattr( self, data_type, - data_type_factory["batch"][data_type](true=None, pred=pred), + 
datatype.data_type_factory["batch"][data_type](true=None, pred=pred), ) - def get_pairs(self, data_type, to_numpy=False): + def get_pairs(self, data_type: str, to_numpy: bool = False) -> tuple: return ( self.get("pred_{}".format(data_type), to_numpy=to_numpy), self.get("true_{}".format(data_type), to_numpy=to_numpy), ) - def count(self, data_type): + def count(self, data_type: str) -> int: if data_type in ["reference", "sequence", "length"]: return self.batch_size if not data_type in self.dt_count or getattr(self, data_type) is None: return 0 return self.dt_count[data_type] - def contains(self, data_type): + def contains(self, data_type: str) -> bool: if data_type in ["reference", "sequence", "length"]: return True - data_part, data_type = split_data_type(data_type) + data_part, data_type = util.split_data_type(data_type) if not self.count(data_type): return False if ( - not hasattr( - getattr(self, data_type), data_part - ) # that's more of a sanity check + not hasattr(getattr(self, data_type), data_part) # that's more of a sanity check or getattr(getattr(self, data_type), data_part) is None ): return False return True - def __len__(self): + def __len__(self) -> int: return self.count("sequence") - - # return out - - def __del__(self): + def __del__(self) -> None: del self.dms del self.shape del self.structure @@ -225,27 +214,26 @@ def __del__(self): del self.data_types del self.dt_count del self - + @property - def device(self): + def device(self) -> str: return self._device @device.getter - def device(self): + def device(self) -> str: return self._device @device.setter - def device(self, device): - # assert device exists - if device == 'mps' and not backends.mps.is_available(): + def device(self, device: str) -> None: + if device == "mps" and not torch.backends.mps.is_available(): raise ValueError("MPS is not available on this device.") - if device == 'cuda' and not cuda.is_available(): + if device == "cuda" and not torch.cuda.is_available(): raise ValueError("CUDA is not available on this device.") - for attr in ['dms', 'shape', 'structure', 'sequence']: + for attr in ["dms", "shape", "structure", "sequence"]: if getattr(self, attr) is not None: setattr(self, attr, getattr(self, attr).to(device)) self._device = device - def to(self, device): + def to(self, device: str) -> "Batch": self.device = device return self diff --git a/efold/core/callbacks.py b/efold/core/callbacks.py index 5593d0a..456123e 100644 --- a/efold/core/callbacks.py +++ b/efold/core/callbacks.py @@ -1,27 +1,10 @@ -import os -from lightning import LightningModule, Trainer import lightning.pytorch as pl -import torch -import numpy as np -import pandas as pd from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.utilities import rank_zero_only import wandb -from typing import Any -from .visualisation import plot_factory -from .metrics import metric_factory -from .datamodule import DataModule -from .loader import Loader -from .batch import Batch -from ..config import ( - TEST_SETS_NAMES, - REF_METRIC_SIGN, - REFERENCE_METRIC, - DATA_TYPES_TEST_SETS, - POSSIBLE_METRICS, -) -from .logger import Logger, LocalLogger +from efold import settings +from efold.core import batch, datamodule, loader, logger, metrics, visualisation class ModelCheckpoint(pl.Callback): @@ -42,6 +25,6 @@ def on_validation_end(self, trainer: Trainer, pl_module, dataloader_idx=0): return name = "{}_epoch{}.pt".format(wandb.run.name, trainer.current_epoch) - loader = Loader(path="models/" + name) + loader_obj = 
loader.Loader(path="models/" + name) # logs what MAE it corresponds to - loader.dump(pl_module) + loader_obj.dump(pl_module) diff --git a/efold/core/dataloader.py b/efold/core/dataloader.py index 705c664..4ac7be9 100644 --- a/efold/core/dataloader.py +++ b/efold/core/dataloader.py @@ -1,6 +1,8 @@ -from torch.utils.data import DataLoader as _DataLoader import torch -from .batch import Batch +from torch.utils.data import DataLoader as _DataLoader + +from efold.core import batch + class DataLoader(_DataLoader): def __init__(self, *args, **kwargs): @@ -8,7 +10,9 @@ def __init__(self, *args, **kwargs): self.to_device = kwargs.pop("to_device") super().__init__(*args, **kwargs) - def transfer_batch_to_device(self, batch: Batch, device: torch.device, dataloader_idx: int) -> Batch: + def transfer_batch_to_device( + self, batch: batch.Batch, device: torch.device, dataloader_idx: int + ) -> batch.Batch: if self.to_device: return batch.to(device) return batch diff --git a/efold/core/datamodule.py b/efold/core/datamodule.py index 1f7862c..8114561 100644 --- a/efold/core/datamodule.py +++ b/efold/core/datamodule.py @@ -1,12 +1,11 @@ -from torch.utils.data import random_split, Subset +import datetime import lightning.pytorch as pl -from typing import Union, List -from .dataset import Dataset -from ..config import TEST_SETS, UKN -from .sampler import sampler_factory -from .dataloader import DataLoader import numpy as np -import datetime +from torch.utils.data import random_split, Subset +from typing import Union, List + +from efold import settings +from efold.core import dataloader, dataset, sampler class DataModule(pl.LightningDataModule): @@ -26,7 +25,7 @@ def __init__( use_error=False, max_len=None, min_len=None, - structure_padding_value=UKN, + structure_padding_value=settings.UKN, tqdm=True, buckets=None, **kwargs, @@ -65,9 +64,9 @@ def __init__( "predict": predict_split, } if strategy in ["ddp", "sorted"]: - assert ( - shuffle_valid == shuffle_train == False - ), "You can't shuffle in ddp or sorted mode. Set shuffle_train and shuffle_valid to 0 or use strategy='random'." + assert shuffle_valid == shuffle_train == False, ( + "You can't shuffle in ddp or sorted mode. Set shuffle_train and shuffle_valid to 0 or use strategy='random'." 
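An illustrative construction that satisfies the check above; the argument names are inferred from the hunks in this file and may not be exhaustive, and `ribo500-blast` is a dataset name taken from the project README.

```python
# Sketch: a DataModule configured for the 'sorted' strategy, which (like 'ddp')
# requires both shuffle flags to be off. Arguments are assumptions, not a
# definitive signature.
from efold.core.datamodule import DataModule

dm = DataModule(
    name=["ribo500-blast"],
    data_type=["structure"],
    strategy="sorted",
    shuffle_train=False,
    shuffle_valid=False,
    batch_size=8,
    num_workers=0,
)
dm.setup("fit")
```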
+ ) self.shuffle = { "train": shuffle_train, "valid": shuffle_valid, @@ -103,12 +102,10 @@ def _dataset_merge(self, datasets): return merge def setup(self, stage: str = None): - if stage is None or ( - stage in ["fit", "predict"] and not hasattr(self, "all_datasets") - ): + if stage is None or (stage in ["fit", "predict"] and not hasattr(self, "all_datasets")): self.all_datasets = self._dataset_merge( [ - Dataset.from_local_or_download( + dataset.Dataset.from_local_or_download( name=name, data_type=self.data_type, sort_by_length=self.strategy == "sorted", @@ -129,9 +126,9 @@ def setup(self, stage: str = None): else self.splits["train"] ) assert num_datapoints > 0, "train_split must be greater than 0" - assert num_datapoints <= len( - self.all_datasets - ), "train_split must be less than the number of datapoints" + assert num_datapoints <= len(self.all_datasets), ( + "train_split must be less than the number of datapoints" + ) self.train_set = Subset( self.all_datasets, range(0, num_datapoints), @@ -140,7 +137,7 @@ def setup(self, stage: str = None): self.external_val_set = [] for name in self.external_valid: self.external_val_set.append( - Dataset.from_local_or_download( + dataset.Dataset.from_local_or_download( name=name, data_type=self.data_type, sort_by_length=True, @@ -164,12 +161,13 @@ def setup(self, stage: str = None): def _select_test_dataset(self): return [ - Dataset.from_local_or_download( + dataset.Dataset.from_local_or_download( name=name, data_type=[data_type], **self.dataset_args, ) - for data_type, datasets in TEST_SETS.items() if data_type in self.data_type + for data_type, datasets in settings.TEST_SETS.items() + if data_type in self.data_type for name in datasets ] @@ -179,21 +177,21 @@ def train_dataloader(self): raise ValueError( "When using strategy='ddp', the trainer must be passed to the datamodule" ) - else: # ddp + else: # ddp num_replicas = self.trainer.num_devices rank = self.trainer.local_rank else: num_replicas = 1 rank = 0 - return DataLoader( + return dataloader.DataLoader( self.train_set, shuffle=self.shuffle["train"], num_workers=self.num_workers, collate_fn=self.collate_fn, batch_size=self.batch_size, to_device=self.strategy != "ddp", - sampler=sampler_factory( + sampler=sampler.sampler_factory( dataset=self.train_set, strategy=self.strategy, num_replicas=num_replicas, @@ -210,13 +208,13 @@ def val_dataloader(self): if self.external_valid is not None: for val_set in self.external_val_set: val_dls.append( - DataLoader( + dataloader.DataLoader( val_set, shuffle=self.shuffle["valid"], collate_fn=self.collate_fn, batch_size=self.batch_size, to_device=self.strategy != "ddp", - sampler=sampler_factory( + sampler=sampler.sampler_factory( dataset=val_set, strategy=self.strategy, num_replicas=self.trainer.num_devices, @@ -229,7 +227,7 @@ def val_dataloader(self): def test_dataloader(self): return [ - DataLoader( + dataloader.DataLoader( test_set, num_workers=self.num_workers, collate_fn=test_set.collate_fn, @@ -239,7 +237,7 @@ def test_dataloader(self): ] def predict_dataloader(self): - return DataLoader( + return dataloader.DataLoader( self.predict_set, num_workers=self.num_workers, collate_fn=self.collate_fn, diff --git a/efold/core/dataset.py b/efold/core/dataset.py index f52b815..5810d57 100644 --- a/efold/core/dataset.py +++ b/efold/core/dataset.py @@ -4,13 +4,10 @@ from torch.utils.data import ConcatDataset, Dataset as TorchDataset, Dataset from typing import List -from .batch import Batch from rouskinhf import get_dataset -from .datatype import DMSDataset, 
SHAPEDataset, StructureDataset -from .embeddings import sequence_to_int -from .util import _pad -from .path import Path -from ..config import UKN + +from efold import settings +from efold.core import batch, datatype, embeddings, path, util class Dataset(TorchDataset): @@ -25,9 +22,9 @@ def __init__( min_len: int, structure_padding_value: float, use_error: bool, - dms: DMSDataset = None, - shape: SHAPEDataset = None, - structure: StructureDataset = None, + dms: datatype.DMSDataset = None, + shape: datatype.SHAPEDataset = None, + structure: datatype.StructureDataset = None, sort_by_length: bool = False, ) -> None: super().__init__() @@ -43,13 +40,13 @@ def __init__( self.structure_padding_value = structure_padding_value self.L = max(self.length) self._remove_sequences_out_of_length_interval(min_len, max_len) - + if sort_by_length: self.sort() def _remove_sequences_out_of_length_interval(self, min_len, max_len): if max_len is None: - max_len = np.inf + max_len = np.inf if min_len is None: min_len = 0 if min_len > max_len: @@ -81,9 +78,7 @@ def __add__(self, other: "Dataset") -> "Dataset": refs=np.concatenate([self.refs, other.refs]).tolist(), length=np.concatenate([self.length, other.length]).tolist(), sequence=np.concatenate([self.sequence, other.sequence]).tolist(), - dms=self.dms + other.dms - if self.dms is not None and other.dms is not None - else None, + dms=self.dms + other.dms if self.dms is not None and other.dms is not None else None, shape=self.shape + other.shape if self.shape is not None and other.shape is not None else None, @@ -101,36 +96,36 @@ def from_local_or_download( use_error: bool = False, max_len=None, min_len=None, - structure_padding_value: float = UKN, + structure_padding_value: float = settings.UKN, sort_by_length: bool = False, tqdm=True, ): - path = Path(name=name) + path_obj = path.Path(name=name) if force_download: - path.clear() + path_obj.clear() - if os.path.exists(path.get_reference()): + if os.path.exists(path_obj.get_reference()): print("Loading dataset from disk") print("Load references \r", end="") - refs = path.load_reference().tolist() + refs = path_obj.load_reference().tolist() print("Load lengths \r", end="") - length = path.load_length().tolist() + length = path_obj.load_length().tolist() L = max(length) print("Load sequences \r", end="") - sequence = path.load_sequence().tolist() + sequence = path_obj.load_sequence().tolist() dms, shape, structure = None, None, None if "dms" in data_type: print("Load dms \r", end="") - dms = path.load_dms() + dms = path_obj.load_dms() if "shape" in data_type: print("Load shape \r", end="") - shape = path.load_shape() + shape = path_obj.load_shape() if "structure" in data_type: print("Load structure \r", end="") - structure = path.load_structure() + structure = path_obj.load_structure() else: data = get_dataset( @@ -142,28 +137,28 @@ def from_local_or_download( print("Dump lengths \r", end="") length = [len(d["sequence"]) for d in data.values()] - path.dump_length(np.array(length)) + path_obj.dump_length(np.array(length)) print("Dump references \r", end="") refs = list(data.keys()) - path.dump_reference(np.array(list(data.keys()))) + path_obj.dump_reference(np.array(list(data.keys()))) L = max(length) print("Dump sequences \r", end="") sequence = [d["sequence"] for d in data.values()] - path.dump_sequence(np.array(sequence)) + path_obj.dump_sequence(np.array(sequence)) print("Dump dms \r", end="") - dms = DMSDataset.from_data_json(data, L, refs) - path.dump_dms(dms) + dms = datatype.DMSDataset.from_data_json(data, 
L, refs) + path_obj.dump_dms(dms) print("Dump shape \r", end="") - shape = SHAPEDataset.from_data_json(data, L, refs) - path.dump_shape(shape) + shape = datatype.SHAPEDataset.from_data_json(data, L, refs) + path_obj.dump_shape(shape) print("Dump structure \r", end="") - structure = StructureDataset.from_data_json(data, L, refs) - path.dump_structure(structure) + structure = datatype.StructureDataset.from_data_json(data, L, refs) + path_obj.dump_structure(structure) print("Done! ") @@ -205,16 +200,14 @@ def __getitem__(self, index) -> tuple: "length": self.length[index], } for attr in ["dms", "shape", "structure"]: - out[attr] = ( - getattr(self, attr)[index] if getattr(self, attr) != None else None - ) + out[attr] = getattr(self, attr)[index] if getattr(self, attr) != None else None return out def collate_fn(self, batch_data): - batch = Batch.from_dataset_items( + batch_obj = batch.Batch.from_dataset_items( batch_data, self.data_type, use_error=self.use_error, structure_padding_value=self.structure_padding_value, ) - return batch + return batch_obj diff --git a/efold/core/datatype.py b/efold/core/datatype.py index 1bcee59..019cd0c 100644 --- a/efold/core/datatype.py +++ b/efold/core/datatype.py @@ -1,7 +1,7 @@ import torch -from ..config import device, UKN, DTYPE_PER_DATA_TYPE -import torch.nn.functional as F -from .util import _pad + +from efold import settings +from efold.core import util class DataType: @@ -17,7 +17,7 @@ def to(self, device): if hasattr(getattr(self, attr), "to"): setattr(self, attr, getattr(self, attr).to(device)) return self - + def __del__(self): del self.true del self.error @@ -46,9 +46,7 @@ def __getitem__(self, idx): def __add__(self, other): if self.name != other.name: - raise ValueError( - f"Cannot concatenate {self.name} and {other.name} datasets." 
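For reference, a sketch of the concatenation contract this guard enforces, shown at the `Dataset` level where `__add__` is defined earlier in this patch. `ribo500-blast` and `PDB` are dataset names cited elsewhere in the repo; the calls below are a sketch, not a tested invocation.

```python
# Sketch: merging two datasets with Dataset.__add__ from core/dataset.py.
# Same-name DataType containers concatenate; mismatched names raise ValueError.
from efold.core.dataset import Dataset

a = Dataset.from_local_or_download(name="ribo500-blast", data_type=["structure"])
b = Dataset.from_local_or_download(name="PDB", data_type=["structure"])

merged = a + b     # refs, lengths and sequences are concatenated,
print(merged.L)    # and merged.L is the new maximum sequence length
```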
- ) + raise ValueError(f"Cannot concatenate {self.name} and {other.name} datasets.") if other is None: return self @@ -69,7 +67,7 @@ def __delitem__(self, idx): del self.error[idx] if self.pred is not None: del self.pred[idx] - + def sort(self, idx_sorted): self.true = [self.true[i] for i in idx_sorted] if self.error is not None: @@ -85,16 +83,14 @@ def from_data_json(cls, data_json: dict, L: int, refs: list): values = data_json[ref] if data_type in values: true.append( - torch.tensor( - values[data_type], dtype=DTYPE_PER_DATA_TYPE[data_type] - ) + torch.tensor(values[data_type], dtype=settings.DTYPE_PER_DATA_TYPE[data_type]) ) if data_type != "structure": if "error_{}".format(data_type) in values: error.append( torch.tensor( values["error_{}".format(data_type)], - dtype=DTYPE_PER_DATA_TYPE[data_type], + dtype=settings.DTYPE_PER_DATA_TYPE[data_type], ) ) else: diff --git a/efold/core/embeddings.py b/efold/core/embeddings.py index 189578a..e8d0d07 100644 --- a/efold/core/embeddings.py +++ b/efold/core/embeddings.py @@ -1,24 +1,29 @@ from torch import nn import torch -from ..config import DEFAULT_FORMAT, UKN, seq2int, int2seq -NUM_BASES = len(set(seq2int.values())) +from efold import settings +NUM_BASES = len(set(settings.seq2int.values())) -def sequence_to_int(sequence: str): - return torch.tensor([seq2int[s] for s in sequence], dtype=torch.int64) +def sequence_to_int(sequence: str) -> torch.Tensor: + return torch.tensor([settings.seq2int[s] for s in sequence], dtype=torch.int64) -def int_to_sequence(sequence: torch.tensor): - return "".join([int2seq[i.item()] for i in sequence]) +def int_to_sequence(sequence: torch.Tensor) -> str: + return "".join([settings.int2seq[i.item()] for i in sequence]) -def sequence_to_one_hot(sequence_batch: torch.tensor): - """Converts a sequence to a one-hot encoding""" - return nn.functional.one_hot(sequence_batch, NUM_BASES).type(DEFAULT_FORMAT) +def sequence_to_one_hot(sequence_batch: torch.Tensor) -> torch.Tensor: + return nn.functional.one_hot(sequence_batch, NUM_BASES).type(settings.DEFAULT_FORMAT) -def base_pairs_to_pairing_matrix(base_pairs, sequence_length, padding, pad_value=UKN): + +def base_pairs_to_pairing_matrix( + base_pairs: torch.Tensor, + sequence_length: int, + padding: int, + pad_value: float = settings.UKN, +) -> torch.Tensor: pairing_matrix = torch.ones((padding, padding)) * pad_value if base_pairs is None: return pairing_matrix @@ -28,4 +33,3 @@ def base_pairs_to_pairing_matrix(base_pairs, sequence_length, padding, pad_value pairing_matrix[base_pairs[:, 0], base_pairs[:, 1]] = 1.0 pairing_matrix[base_pairs[:, 1], base_pairs[:, 0]] = 1.0 return pairing_matrix - diff --git a/efold/core/loader.py b/efold/core/loader.py index 8c39946..ff6d33d 100644 --- a/efold/core/loader.py +++ b/efold/core/loader.py @@ -13,31 +13,29 @@ def __init__( makedirs(dirname(self.get_path()), exist_ok=True) @classmethod - def find_best_model(cls, prefix): + def find_best_model(cls, prefix: str) -> "Loader": models = [model for model in listdir("models") if model.startswith(prefix)] if len(models) == 0: return None - models.sort( - key=lambda x: float(x.split("_mae:")[-1].split(".")[0].replace("-", ".")) - ) + models.sort(key=lambda x: float(x.split("_mae:")[-1].split(".")[0].replace("-", "."))) return cls(path="models/" + models[0]) - def get_path(self): + def get_path(self) -> str: return self.path - def get_name(self): + def get_name(self) -> str: return self.path.split("/")[-1].split(".")[0] - def write_in_log(self, epoch, mae): + def write_in_log(self, epoch: 
int, mae: float) -> "Loader": with open("models/_log.txt", "a") as f: f.write(f"{epoch} {self.get_name()}\t{mae}\n") return self - def load_from_weights(self, safe_load=True): + def load_from_weights(self, safe_load: bool = True) -> dict: if safe_load and os.path.exists(self.get_path()) or not safe_load: return torch.load(self.get_path()) raise FileNotFoundError(f"File {self.get_path()} not found") - def dump(self, model): + def dump(self, model) -> "Loader": torch.save(model.state_dict(), self.get_path()) return self diff --git a/efold/core/logger.py b/efold/core/logger.py index 6c248a4..a6fbeab 100644 --- a/efold/core/logger.py +++ b/efold/core/logger.py @@ -1,9 +1,9 @@ import wandb -from ..config import * -import lightning.pytorch as pl import os +import lightning.pytorch as pl import matplotlib.pyplot as plt -import torchmetrics + +from efold.settings import * class LocalLogger: @@ -17,9 +17,7 @@ def __init__(self, path: str = "local_testing_output", overwrite: bool = False): def test_plot(self, dataloader, data_type, name, plot: plt.Figure, idx=None): # save the wandb Image to a png - plot.savefig( - os.path.join(self.path, f"{dataloader}_{data_type}_{name}_{idx}.png") - ) + plot.savefig(os.path.join(self.path, f"{dataloader}_{data_type}_{name}_{idx}.png")) plt.close(plot) diff --git a/efold/core/metrics.py b/efold/core/metrics.py index 6f2f25b..e64f09e 100644 --- a/efold/core/metrics.py +++ b/efold/core/metrics.py @@ -1,9 +1,8 @@ -import torch -from ..config import UKN, POSSIBLE_METRICS -import torch -from .batch import Batch import numpy as np -from typing import TypedDict +import torch + +from efold import settings +from efold.core import batch # wrapper for metrics @@ -11,7 +10,7 @@ def mask_and_flatten(func): def wrapped(pred, true): if pred is None or true is None: return np.nan - mask = true != UKN + mask = true != settings.UKN if torch.sum(mask) == 0: return np.nan pred = pred[mask] @@ -84,9 +83,7 @@ def r2_score(pred, true): :return: R2 score """ - return ( - 1 - torch.sum((true - pred) ** 2) / torch.sum((true - torch.mean(true)) ** 2) - ).item() + return (1 - torch.sum((true - pred) ** 2) / torch.sum((true - torch.mean(true)) ** 2)).item() @mask_and_flatten @@ -100,9 +97,7 @@ def pearson_coefficient(pred, true): """ return torch.mean( - (pred - torch.mean(pred)) - * (true - torch.mean(true)) - / (torch.std(pred) * torch.std(true)) + (pred - torch.mean(pred)) * (true - torch.mean(true)) / (torch.std(pred) * torch.std(true)) ).item() @@ -135,10 +130,10 @@ def __init__(self, name, data_type=["dms", "shape", "structure"]): self.shape = dict(mae=[], pearson=[], r2=[]) self.structure = dict(f1=[]) - def update(self, batch: Batch): + def update(self, batch: batch.Batch): for dt in self.data_type: pred, true = batch.get_pairs(dt) - for metric in POSSIBLE_METRICS[dt]: + for metric in settings.POSSIBLE_METRICS[dt]: self._add_metric(dt, metric, metric_factory[metric](pred, true)) return self @@ -146,7 +141,7 @@ def compute(self) -> dict: out = {} for dt in self.data_type: out[dt] = {} - for metric in POSSIBLE_METRICS[dt]: + for metric in settings.POSSIBLE_METRICS[dt]: out[dt][metric] = self._get_nanmean(dt, metric) return out diff --git a/efold/core/model.py b/efold/core/model.py index f854e95..f89886e 100644 --- a/efold/core/model.py +++ b/efold/core/model.py @@ -1,17 +1,12 @@ from typing import Any import lightning.pytorch as pl from lightning.pytorch.utilities.types import STEP_OUTPUT -import torch.nn as nn import torch -from ..config import device, UKN, TEST_SETS_NAMES +import 
torch.nn as nn import torch.nn.functional as F -from .batch import Batch -from torchmetrics import R2Score, PearsonCorrCoef, MeanAbsoluteError, F1Score -from .metrics import MetricsStack -from .datamodule import DataModule -import time -from .postprocess import Postprocess +from efold import settings +from efold.core import batch, metrics, postprocess METRIC_ARGS = dict(dist_sync_on_step=True) @@ -45,16 +40,16 @@ def __init__(self, lr: float, optimizer_fn, weight_data: bool = False, **kwargs) self.automatic_optimization = True self.weight_data = weight_data - self.save_hyperparameters(ignore=['loss_fn']) - self.lossBCE = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([300])).to(device) + self.save_hyperparameters(ignore=["loss_fn"]) + self.lossBCE = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([300])).to(settings.device) # Metrics self.metrics_stack = None self.tic = None - self.test_results = {'reference':[], 'sequence':[] ,'structure':[]} + self.test_results = {"reference": [], "sequence": [], "structure": []} - self.postprocesser = Postprocess() + self.postprocesser = postprocess.Postprocess() def configure_optimizers(self): optimizer = self.optimizer_fn( @@ -69,7 +64,7 @@ def configure_optimizers(self): scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.gamma) return [optimizer], [scheduler] - def _loss_signal(self, batch: Batch, data_type: str): + def _loss_signal(self, batch: batch.Batch, data_type: str): assert data_type in [ "dms", "shape", @@ -85,7 +80,7 @@ def _loss_signal(self, batch: Batch, data_type: str): ## vv MSE loss vv ## mask = torch.zeros_like(true) - mask[true != UKN] = 1 + mask[true != settings.UKN] = 1 loss = F.mse_loss(pred * mask, true * mask) non_zeros = (mask == 1).sum() / mask.numel() @@ -96,13 +91,13 @@ def _loss_signal(self, batch: Batch, data_type: str): assert not torch.isnan(loss), "Loss is NaN for {}".format(data_type) return loss - def _loss_structure(self, batch: Batch): + def _loss_structure(self, batch: batch.Batch): pred, true = batch.get_pairs("structure") loss = self.lossBCE(pred, true) assert not torch.isnan(loss), "Loss is NaN for structure" return loss - def loss_fn(self, batch: Batch): + def loss_fn(self, batch: batch.Batch): count = {k: v for k, v in batch.dt_count.items() if k in self.data_type_output} losses = {} if "dms" in count.keys(): @@ -127,7 +122,7 @@ def _clean_predictions(self, batch, predictions): predictions[data_type] = torch.clip(predictions[data_type], min=0, max=1) return predictions - def training_step(self, batch: Batch, batch_idx: int): + def training_step(self, batch: batch.Batch, batch_idx: int): predictions = self.forward(batch) batch.integrate_prediction(predictions) loss = self.loss_fn(batch)[0] @@ -137,13 +132,11 @@ def training_step(self, batch: Batch, batch_idx: int): def on_validation_start(self): val_dl_names = self.trainer.datamodule.external_valid self.metrics_stack = [ - MetricsStack(name=name, data_type=self.data_type_output) + metrics.MetricsStack(name=name, data_type=self.data_type_output) for name in val_dl_names ] - def on_train_batch_end( - self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int - ) -> None: + def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None: del outputs del batch if batch_idx % 100 == 0: @@ -152,10 +145,12 @@ def on_train_batch_end( def on_train_end(self) -> None: torch.cuda.empty_cache() - def validation_step(self, batch: Batch, batch_idx: int, dataloader_idx=0): + def validation_step(self, batch: batch.Batch, batch_idx: int, 
dataloader_idx=0): predictions = self.forward(batch) - - predictions['structure'] = self.postprocesser.run(predictions['structure'], batch.get('sequence')) + + predictions["structure"] = self.postprocesser.run( + predictions["structure"], batch.get("sequence") + ) batch.integrate_prediction(predictions) # loss, losses = self.loss_fn(batch) @@ -185,14 +180,20 @@ def on_validation_epoch_end(self) -> None: self.metrics_stack = None torch.cuda.empty_cache() - def test_step(self, batch: Batch, batch_idx: int, dataloader_idx=0): + def test_step(self, batch: batch.Batch, batch_idx: int, dataloader_idx=0): predictions = self.forward(batch) - predictions['structure'] = self.postprocesser.run(predictions['structure'], batch.get('sequence')) + predictions["structure"] = self.postprocesser.run( + predictions["structure"], batch.get("sequence") + ) - from ..config import int2seq - self.test_results['reference'] += batch.get('reference') - self.test_results['sequence'] += [''.join([int2seq[base] for base in seq]) for seq in batch.get('sequence').detach().tolist()] - self.test_results['structure'] += predictions['structure'].tolist() + from efold import settings + + self.test_results["reference"] += batch.get("reference") + self.test_results["sequence"] += [ + "".join([settings.int2seq[base] for base in seq]) + for seq in batch.get("sequence").detach().tolist() + ] + self.test_results["structure"] += predictions["structure"].tolist() predictions = self._clean_predictions(batch, predictions) batch.integrate_prediction(predictions) @@ -201,12 +202,12 @@ def on_test_batch_end( self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0 ) -> None: # push the metric directly - metric_pack = MetricsStack( - name=TEST_SETS_NAMES[dataloader_idx], + metric_pack = metrics.MetricsStack( + name=settings.TEST_SETS_NAMES[dataloader_idx], data_type=self.data_type_output, - ) - for dt, metrics in metric_pack.update(batch).compute().items(): - for name, metric in metrics.items(): + ) + for dt, metric_dict in metric_pack.update(batch).compute().items(): + for name, metric in metric_dict.items(): self.log( f"test/{metric_pack.name}/{dt}/{name}", float(metric), @@ -218,16 +219,16 @@ def on_test_batch_end( def on_test_epoch_end(self) -> None: torch.cuda.empty_cache() - + def on_test_end(self) -> None: - import pandas as pd + df = pd.DataFrame(self.test_results) - df.to_feather('test_results.feather') + df.to_feather("test_results.feather") torch.cuda.empty_cache() - def predict_step(self, batch: Batch, batch_idx: int): + def predict_step(self, batch: batch.Batch, batch_idx: int): predictions = self.forward(batch) predictions = self._clean_predictions(batch, predictions) batch.integrate_prediction(predictions) diff --git a/efold/core/path.py b/efold/core/path.py index df881ca..50a70c4 100644 --- a/efold/core/path.py +++ b/efold/core/path.py @@ -49,36 +49,30 @@ def get_reference(self) -> str: """Returns the path to the references.txt file.""" return join(self.get_main_folder(), "references.npy") - def load_reference(self): - """Returns the list of references.""" + def load_reference(self) -> np.ndarray: return np.load(self.get_reference(), allow_pickle=True) - def dump_reference(self, references): - """Dumps the list of references.""" + def dump_reference(self, references: np.ndarray) -> None: np.save(self.get_reference(), references) def get_sequence(self) -> str: """Returns the path to the sequences.txt file.""" return join(self.get_main_folder(), "sequences.npy") - def load_sequence(self): - """Returns 
the list of sequences.""" + def load_sequence(self) -> np.ndarray: return np.load(self.get_sequence(), allow_pickle=True) - def dump_sequence(self, sequences): - """Dumps the list of sequences.""" + def dump_sequence(self, sequences: np.ndarray) -> None: np.save(self.get_sequence(), sequences) def get_length(self) -> str: """Returns the path to the lengths.txt file.""" return join(self.get_main_folder(), "lengths.npy") - def load_length(self): - """Returns the list of lengths.""" + def load_length(self) -> np.ndarray: return np.load(self.get_length(), allow_pickle=True) - def dump_length(self, lengths): - """Dumps the list of lengths.""" + def dump_length(self, lengths: np.ndarray) -> None: np.save(self.get_length(), lengths) def get_dms(self) -> str: @@ -86,14 +80,12 @@ def get_dms(self) -> str: return join(self.get_main_folder(), "dms.pkl") def load_dms(self): - """Returns the list of dms.""" if not os.path.exists(self.get_dms()): return None return pickle.load(open(self.get_dms(), "rb")) @dont_dump_none - def dump_dms(self, val) -> str: - """Dumps the list of dms.""" + def dump_dms(self, val) -> None: pickle.dump(val, open(self.get_dms(), "wb")) def get_shape(self) -> str: @@ -101,14 +93,12 @@ def get_shape(self) -> str: return join(self.get_main_folder(), "shapes.pkl") def load_shape(self): - """Returns the list of shapes.""" if not os.path.exists(self.get_shape()): return None return pickle.load(open(self.get_shape(), "rb")) @dont_dump_none - def dump_shape(self, val): - """Dumps the list of shapes.""" + def dump_shape(self, val) -> None: pickle.dump(val, open(self.get_shape(), "wb")) def get_structure(self) -> str: @@ -116,12 +106,10 @@ def get_structure(self) -> str: return join(self.get_main_folder(), "structures.pkl") def load_structure(self): - """Returns the list of structures.""" if not os.path.exists(self.get_structure()): return None return pickle.load(open(self.get_structure(), "rb")) @dont_dump_none - def dump_structure(self, val): - """Dumps the list of structures.""" + def dump_structure(self, val) -> None: pickle.dump(val, open(self.get_structure(), "wb")) diff --git a/efold/core/postprocess.py b/efold/core/postprocess.py index 1448544..0903207 100644 --- a/efold/core/postprocess.py +++ b/efold/core/postprocess.py @@ -1,13 +1,13 @@ -import torch -import math import numpy as np +import torch import torch.nn.functional as F from scipy.optimize import linear_sum_assignment -from ..config import seq2int -class Constraints: +from efold import settings - """ Compute a mask of constraints to remove sharp loops and non-canonical pairs from the input matrix + +class Constraints: + """Compute a mask of constraints to remove sharp loops and non-canonical pairs from the input matrix Args: - input_matrix (torch.Tensor): n x n matrix of base pair probabilities @@ -17,48 +17,48 @@ class Constraints: Example: >>> inpt = torch.tensor([[0.1, 0.6, 0.8],[0.6, 0.1, 0.9],[0.8, 0.9, 0.1]]) - >>> sequence = torch.tensor([seq2int[a] for a in "GCU"]) + >>> sequence = torch.tensor([settings.seq2int[a] for a in "GCU"]) >>> out = Constraints().apply_constraints(inpt, sequence=sequence, min_hairpin_length=0, canonical_only=True) >>> assert (out == torch.tensor([[0.0, 0.6, 0.8],[0.6, 0.0, 0.0],[0.8, 0.0, 0.0]])).all(), "The output is not as expected: {}".format(out) """ - def apply_constraints(self, input_matrix, min_hairpin_length=3, canonical_only=True, sequence=None): - + def apply_constraints( + self, input_matrix, min_hairpin_length=3, canonical_only=True, sequence=None + ): # mask elements of 
the diagonals and sub-diagonals mask = self.mask_sharpLoops(input_matrix, min_hairpin_length) # Mask elements of the matrix that are not A-U, G-C, or G-U pairs using the sequence - if canonical_only: mask *= self.mask_nonCanonical(sequence) - + if canonical_only: + mask *= self.mask_nonCanonical(sequence) + return input_matrix * mask def mask_sharpLoops(self, input_matrix, min_hairpin_length): - # mask elements of the diagonals and sub-diagonals - mask = np.tri(input_matrix.shape[0], k=-min_hairpin_length-1).astype(int) + mask = np.tri(input_matrix.shape[0], k=-min_hairpin_length - 1).astype(int) return torch.tensor(mask + mask.T, device=input_matrix.device) def mask_nonCanonical(self, sequence): - # Embed sequence - if type(sequence) == str: sequence = torch.tensor([seq2int[a] for a in sequence]) - + if type(sequence) == str: + sequence = torch.tensor([settings.seq2int[a] for a in sequence]) + # make the pairing matrix sequence = sequence.reshape(-1, 1) pair_of_bases = sequence + sequence.T # find the allowable pairs allowable_pair = set() - for pair in ["GU", "GC", "AU"]: allowable_pair.add(seq2int[pair[0]] + seq2int[pair[1]]) + for pair in ["GU", "GC", "AU"]: + allowable_pair.add(settings.seq2int[pair[0]] + settings.seq2int[pair[1]]) allowable_pair = torch.tensor(list(allowable_pair), device=pair_of_bases.device) return torch.isin(pair_of_bases, allowable_pair).int() - class HungarianAlgorithm: - def run(self, bppm, threshold=0.5): """Runs the Hungarian algorithm on the input bppm matrix @@ -84,25 +84,25 @@ def run(self, bppm, threshold=0.5): >>> out = HungarianAlgorithm().run(torch.tensor([[0., 0.6, 0.8],[0.6, 0., 0.9],[0.8, 0.9, 0.]]) ) >>> assert (out == [[0., 1., 1.],[1., 0., 1.],[1., 1., 0.]]).all(), "The output is not as expected: {}".format(out) """ - + assert len(bppm.shape) == 2, "The input bppm matrix should be n x n" assert bppm.shape[0] == bppm.shape[1], "The input bppm matrix should be n x n" assert self.is_symmetric(bppm), f"The input bppm matrix should be symmetric, {bppm}" - + # just work with numpy (needed for the optimization step) - if type(bppm)==torch.Tensor: + if type(bppm) == torch.Tensor: device = bppm.device bppm = bppm.cpu().numpy() - - # run hungarian algorithm - bp_matrix = np.zeros(bppm.shape) - + + # run hungarian algorithm + bp_matrix = np.zeros(bppm.shape) + # run hungarian algorithm only on rows and columns that have at least one value greater than threshold compression_idx = self._pairable_bases(bppm, threshold) compressed_bppm = bppm[compression_idx][:, compression_idx] - row_ind, col_ind = self._hungarian_algorithm(compressed_bppm) - + row_ind, col_ind = self._hungarian_algorithm(compressed_bppm) + # convert the result to the original sized matrix for compressed_row, compressed_col in zip(row_ind, col_ind): bppm_row, bppm_col = compression_idx[compressed_row], compression_idx[compressed_col] @@ -111,31 +111,29 @@ def run(self, bppm, threshold=0.5): bp_matrix[bppm_col, bppm_row] = 1 return torch.tensor(bp_matrix, device=device) - + def _hungarian_algorithm(self, cost_matrix): """Returns the row and column indices of the optimal assignment using the Hungarian algorithm""" row_ind, col_ind = linear_sum_assignment(cost_matrix, maximize=True) return row_ind, col_ind - + def _pairable_bases(self, bppm, threshold): """Returns the indices of rows that have at least one value greater than threshold - + Example: >>> assert (HungarianAlgorithm()._pairable_bases(np.array([[0.0, 0.0, 0.0], [0.0, 0.6, 0.8], [0.0, 0.8, 0.9]]), 0.01) == np.array([1, 2])).all(), "The 
output is not as expected" """ return np.where((bppm > threshold).any(axis=0))[0] - + def is_symmetric(self, bppm): return (bppm == bppm.transpose(1, 0)).all() - class UFold_processing: - def run(self, bppm): return self.postprocess(u=bppm) - - def postprocess(self, u, lr_min=0.01, lr_max=0.1, num_itr=100, rho=1.6, with_l1=True,s=1.5): + + def postprocess(self, u, lr_min=0.01, lr_max=0.1, num_itr=100, rho=1.6, with_l1=True, s=1.5): """ :param u: utility matrix, u is assumed to be symmetric :param lr_min: learning rate for minimization step @@ -145,10 +143,11 @@ def postprocess(self, u, lr_min=0.01, lr_max=0.1, num_itr=100, rho=1.6, with_l1= :param with_l1: :return: """ + def soft_sign(x): k = 1 - return 1.0/(1.0+torch.exp(-2*k*x)) - + return 1.0 / (1.0 + torch.exp(-2 * k * x)) + def contact_a(a_hat, m): a = a_hat * a_hat a = (a + torch.transpose(a, -1, -2)) / 2 @@ -167,8 +166,9 @@ def contact_a(a_hat, m): # gradient descent for t in range(num_itr): - - grad_a = (lmbd * soft_sign(torch.sum(contact_a(a_hat, m), dim=-1) - 1)).unsqueeze_(-1).expand(u.shape) - u / 2 + grad_a = (lmbd * soft_sign(torch.sum(contact_a(a_hat, m), dim=-1) - 1)).unsqueeze_( + -1 + ).expand(u.shape) - u / 2 grad = a_hat * m * (grad_a + torch.transpose(grad_a, -1, -2)) a_hat -= lr_min * grad lr_min = lr_min * 0.99 @@ -192,34 +192,33 @@ def contact_a(a_hat, m): a = (a + torch.transpose(a, -1, -2)) / 2 a = a * m return a - -class Postprocess: +class Postprocess: def __init__(self, threshold=0.5, canonical_only=True, min_hairpin_length=3): self.threshold = threshold self.canonical_only = canonical_only self.min_hairpin_length = min_hairpin_length def run(self, bppms, sequence): - if len(bppms.shape) == 2: bppms = bppms.unsqueeze(0) pairing_matrices = [] for bppm in bppms: + pairing_matrix = Constraints().apply_constraints( + bppm, + sequence=sequence, + min_hairpin_length=self.min_hairpin_length, + canonical_only=self.canonical_only, + ) - pairing_matrix = Constraints().apply_constraints(bppm, sequence=sequence, - min_hairpin_length=self.min_hairpin_length, - canonical_only=self.canonical_only) - pairing_matrix_UFold = UFold_processing().run(pairing_matrix) - if not pairing_matrix_UFold.isnan().any(): pairing_matrix = pairing_matrix_UFold - + if not pairing_matrix_UFold.isnan().any(): + pairing_matrix = pairing_matrix_UFold pairing_matrix = HungarianAlgorithm().run(pairing_matrix, threshold=self.threshold) pairing_matrices.append(pairing_matrix) return (torch.stack(pairing_matrices) > self.threshold).type(torch.int) - diff --git a/efold/core/sampler.py b/efold/core/sampler.py index d76c1ac..fbf44d2 100644 --- a/efold/core/sampler.py +++ b/efold/core/sampler.py @@ -1,5 +1,6 @@ from torch.utils.data import Sampler, Subset import numpy as np + # from random import shuffle from torch.utils.data import Dataset from typing import Union, Optional, TypeVar, Iterator @@ -8,10 +9,10 @@ import torch import os -T_co = TypeVar('T_co', covariant=True) +T_co = TypeVar("T_co", covariant=True) -class DDPSampler(Sampler): +class DDPSampler(Sampler): r"""Sampler that restricts data loading to a subset of the dataset. It is especially useful in conjunction with @@ -58,12 +59,17 @@ class DDPSampler(Sampler): ... if is_distributed: ... sampler.set_epoch(epoch) ... 
train(loader) - """ - def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, - rank: Optional[int] = None, shuffle: bool = True, - seed: int = os.environ.get('PL_GLOBAL_SEED', 0), - drop_last: bool = False) -> None: - + """ + + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = os.environ.get("PL_GLOBAL_SEED", 0), + drop_last: bool = False, + ) -> None: if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") @@ -74,7 +80,8 @@ def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, rank = dist.get_rank() if rank >= num_replicas or rank < 0: raise ValueError( - f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}] because num_replicas={num_replicas}") + f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}] because num_replicas={num_replicas}" + ) self.dataset = dataset self.num_replicas = num_replicas self.rank = rank @@ -94,12 +101,14 @@ def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle self.seed = seed - + if isinstance(dataset, Subset): - self.length = dataset.dataset.length[slice(dataset.indices.start, dataset.indices.stop, dataset.indices.step)] + self.length = dataset.dataset.length[ + slice(dataset.indices.start, dataset.indices.stop, dataset.indices.step) + ] elif isinstance(dataset, Dataset): - self.length = dataset.length - + self.length = dataset.length + def __iter__(self) -> Iterator[T_co]: # deterministically shuffle based on epoch and seed if self.shuffle: @@ -118,24 +127,24 @@ def __iter__(self) -> Iterator[T_co]: indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] else: # remove tail of data to make it evenly divisible. 
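                # e.g. 10 samples across 4 replicas: with drop_last the index list
                # is trimmed to total_size = 8 (2 per rank); with drop_last=False the
                # branch above instead pads it up to total_size = 12 (3 per rank)
                # by repeating indices.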
- indices = indices[:self.total_size] + indices = indices[: self.total_size] assert len(indices) == self.total_size # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples # sort by length - lengths = [self.length[i] for i in indices] + lengths = [self.length[i] for i in indices] length_order = np.argsort(lengths) indices = [indices[i] for i in length_order] - + # shuffle them again to avoid having the samples sorted by length g = torch.Generator() g.manual_seed(self.seed + self.epoch + 42) deterministic_order = torch.randperm(len(indices), generator=g).tolist() indices = [indices[i] for i in deterministic_order] - + return iter(indices) def __len__(self) -> int: @@ -156,14 +165,13 @@ def set_epoch(self, epoch: int) -> None: def sampler_factory( dataset: Union[Dataset, Subset], strategy: str, - seed:int = os.environ.get('PL_GLOBAL_SEED', 0), + seed: int = os.environ.get("PL_GLOBAL_SEED", 0), num_replicas: Optional[int] = None, rank: Optional[int] = None, ): - if strategy in ['random', 'sorted']: + if strategy in ["random", "sorted"]: return None - elif strategy == 'ddp': + elif strategy == "ddp": return DDPSampler(dataset, num_replicas=num_replicas, rank=rank, shuffle=True, seed=seed) else: raise ValueError(f"Invalid strategy value: {strategy}") - \ No newline at end of file diff --git a/efold/core/util.py b/efold/core/util.py index 0bd09d2..4072c9b 100644 --- a/efold/core/util.py +++ b/efold/core/util.py @@ -1,25 +1,25 @@ -from ..config import UKN -from .embeddings import base_pairs_to_pairing_matrix -import torch.nn.functional as F -from torch import tensor import torch +import torch.nn.functional as F + +from efold import settings +from efold.core import embeddings -def _pad(arr, L, data_type): +def _pad(arr: torch.Tensor, L: int, data_type: str) -> torch.Tensor: padding_values = { "sequence": 0, - "dms": UKN, - "shape": UKN, + "dms": settings.UKN, + "shape": settings.UKN, } if data_type == "structure": - return base_pairs_to_pairing_matrix(arr, L) + return embeddings.base_pairs_to_pairing_matrix(arr, L) else: if isinstance(arr, list): arr = torch.tensor(arr) return F.pad(arr, (0, L - arr.shape[1]), value=padding_values[data_type]) -def split_data_type(data_type): +def split_data_type(data_type: str) -> tuple[str, str]: if "_" not in data_type: data_part = "true" else: diff --git a/efold/core/visualisation.py b/efold/core/visualisation.py index 01c78ef..0a6f281 100644 --- a/efold/core/visualisation.py +++ b/efold/core/visualisation.py @@ -1,10 +1,10 @@ from matplotlib import pyplot as plt import numpy as np -import wandb -from .metrics import r2_score, mae_score, pearson_coefficient -from ..config import UKN from rouskinhf import int2seq +from efold import settings +from efold.core import metrics + matplotlib_colors = [ "red", "blue", @@ -38,9 +38,9 @@ def plot_signal( fig, ax = plt.subplots() # Compute metrics while you still have tensors - r2 = r2_score(pred=pred, true=true) - r = pearson_coefficient(pred=pred, true=true) - mae = mae_score(pred=pred, true=true) + r2 = metrics.r2_score(pred=pred, true=true) + r = metrics.pearson_coefficient(pred=pred, true=true) + mae = metrics.mae_score(pred=pred, true=true) # Base position with no coverage or G/U base are removed def chop_array(x): @@ -50,7 +50,7 @@ def known_bases_to_list(x, mask): return x[mask].cpu().numpy() pred, true, sequence = chop_array(pred), chop_array(true), chop_array(sequence) - mask = true != 
UKN + mask = true != settings.UKN true, pred, sequence = ( known_bases_to_list(true, mask), known_bases_to_list(pred, mask), @@ -129,11 +129,7 @@ def plot_structure(pred, true, **kwargs): plot_factory = { - ("dms", "scatter"): lambda *args, **kwargs: plot_signal( - *args, **kwargs, data_type="DMS" - ), - ("shape", "scatter"): lambda *args, **kwargs: plot_signal( - *args, **kwargs, data_type="SHAPE" - ), + ("dms", "scatter"): lambda *args, **kwargs: plot_signal(*args, **kwargs, data_type="DMS"), + ("shape", "scatter"): lambda *args, **kwargs: plot_signal(*args, **kwargs, data_type="SHAPE"), ("structure", "heatmap"): plot_structure, } diff --git a/efold/models/__init__.py b/efold/models/__init__.py index 594023d..e69de29 100644 --- a/efold/models/__init__.py +++ b/efold/models/__init__.py @@ -1 +0,0 @@ -from .factory import create_model diff --git a/efold/models/cnn.py b/efold/models/cnn.py index d664df3..2a02820 100644 --- a/efold/models/cnn.py +++ b/efold/models/cnn.py @@ -1,19 +1,13 @@ +import numpy as np import torch from torch import nn, Tensor -from torch.nn import TransformerEncoderLayer -import numpy as np -import os -import sys -from ..core.model import Model -from ..core.batch import Batch from einops import rearrange import torch.nn.functional as F -dir_name = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(os.path.join(dir_name, "..")) +from efold.core import batch, model -class CNN(Model): +class CNN(model.Model): def __init__( self, ntoken: int, @@ -89,7 +83,7 @@ def __init__( # nn.Linear(d_model, 1), # ) - def forward(self, batch: Batch) -> Tensor: + def forward(self, batch: batch.Batch) -> Tensor: """ Args: src: Tensor, shape [seq_len, batch_size] @@ -105,9 +99,7 @@ def forward(self, batch: Batch) -> Tensor: # Outer concatenation matrix = x.unsqueeze(1).repeat(1, x.shape[1], 1, 1) # (N, L, L, d_cnn/2) - matrix = torch.cat( - (matrix, matrix.permute(0, 2, 1, 3)), dim=-1 - ) # (N, L, L, d_cnn) + matrix = torch.cat((matrix, matrix.permute(0, 2, 1, 3)), dim=-1) # (N, L, L, d_cnn) # Resnet layers matrix = self.res_layers(matrix.permute(0, 3, 1, 2)) # (N, d_cnn//8, L, L) @@ -225,9 +217,7 @@ def __init__(self, n_blocks, dim_in, dim_out, kernel_size, dropout=0.0): self.res_blocks = nn.Sequential(*self.res_layers) # Adapter to change depth - self.conv_output = nn.Conv2d( - dim_in, dim_out, kernel_size=7, padding=3, bias=True - ) + self.conv_output = nn.Conv2d(dim_in, dim_out, kernel_size=7, padding=3, bias=True) def forward(self, x: Tensor) -> Tensor: x = self.res_blocks(x) @@ -252,14 +242,10 @@ def __init__( self.bn1 = nn.BatchNorm2d(inplanes) self.relu1 = nn.ReLU(inplace=True) - self.conv1 = conv3x3( - inplanes, planes, dilation=dilation1, kernel_size=kernel_size - ) + self.conv1 = conv3x3(inplanes, planes, dilation=dilation1, kernel_size=kernel_size) self.dropout = nn.Dropout(p=dropout) self.relu2 = nn.ReLU(inplace=True) - self.conv2 = conv3x3( - planes, planes, dilation=dilation2, kernel_size=kernel_size - ) + self.conv2 = conv3x3(planes, planes, dilation=dilation2, kernel_size=kernel_size) def forward(self, x: Tensor) -> Tensor: identity = x @@ -276,9 +262,7 @@ def forward(self, x: Tensor) -> Tensor: return out -def conv3x3( - in_planes: int, out_planes: int, dilation: int = 1, kernel_size=3 -) -> nn.Conv2d: +def conv3x3(in_planes: int, out_planes: int, dilation: int = 1, kernel_size=3) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d( in_planes, diff --git a/efold/models/efold.py b/efold/models/efold.py index cdda582..be9ca9e 100644 --- 
a/efold/models/efold.py +++ b/efold/models/efold.py @@ -1,24 +1,15 @@ import numpy as np import torch from torch import nn, Tensor -import os -import sys from contextlib import ExitStack - -import typing as T from einops import rearrange import torch.nn.functional as F -import numpy as np -from ..core.batch import Batch -from ..core.model import Model +from collections import defaultdict -from collections import defaultdict +from efold.core import batch, model -dir_name = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(os.path.join(dir_name, "..")) - -class eFold(Model): +class eFold(model.Model): def __init__( self, ntoken: int, @@ -34,10 +25,8 @@ def __init__( optimizer_fn=torch.optim.Adam, **kwargs, ): - self.save_hyperparameters(ignore=['loss_fn']) - super().__init__( - lr=lr, loss_fn=loss_fn, optimizer_fn=optimizer_fn, **kwargs - ) + self.save_hyperparameters(ignore=["loss_fn"]) + super().__init__(lr=lr, loss_fn=loss_fn, optimizer_fn=optimizer_fn, **kwargs) self.model_type = "eFold" self.data_type_output = ["structure"] @@ -86,17 +75,15 @@ def __init__( kernel_size=3, dropout=dropout, ), - ResLayer( - dim_in=d_cnn // 2, dim_out=1, n_blocks=4, kernel_size=3, dropout=dropout - ), + ResLayer(dim_in=d_cnn // 2, dim_out=1, n_blocks=4, kernel_size=3, dropout=dropout), ) - def forward(self, batch: Batch) -> Tensor: + def forward(self, batch: batch.Batch) -> Tensor: # Encoding of RNA sequence src = batch.get("sequence") - + s = self.encoder(src) # (N, L, d_model) - z = self.encoder_adapter(self.seq2map(src)).permute(0, 2, 3, 1) # (N, L, L, d_model) + z = self.encoder_adapter(self.seq2map(src)).permute(0, 2, 3, 1) # (N, L, L, d_model) # z = self.activ(self.encoder_adapter(s)) # (N, L, c_z / 2) # # Outer concatenation @@ -106,63 +93,75 @@ def forward(self, batch: Batch) -> Tensor: s, z = self.eFold(s, z) structure = self.structure_adapter(z) # (N, L, L, d_cnn) - structure = self.output_structure(structure.permute(0, 3, 1, 2)).squeeze( - 1 - ) # (N, L, L) + structure = self.output_structure(structure.permute(0, 3, 1, 2)).squeeze(1) # (N, L, L) return { # "dms": self.output_net_DMS(s).squeeze(axis=2), # "shape": self.output_net_SHAPE(s).squeeze(axis=2), - "structure": (structure + structure.permute(0, 2, 1)) - / 2, + "structure": (structure + structure.permute(0, 2, 1)) / 2, } - - def seq2map(self, seq_int): + def seq2map(self, seq_int): def int2seq(seq): # return ''.join(['XAUCG'[d] for d in seq]) - return ''.join(['XACGU'[d] for d in seq]) + return "".join(["XACGU"[d] for d in seq]) # take integer encoded sequence and return last channel of embedding (pairing energy) def creatmat(data, device=None): - with torch.no_grad(): data = int2seq(data) - paired = defaultdict(float, {'AU':2., 'UA':2., 'GC':3., 'CG':3., 'UG':0.8, 'GU':0.8}) + paired = defaultdict( + float, {"AU": 2.0, "UA": 2.0, "GC": 3.0, "CG": 3.0, "UG": 0.8, "GU": 0.8} + ) - mat = torch.tensor([[paired[x+y] for y in data] for x in data]).to(device) + mat = torch.tensor([[paired[x + y] for y in data] for x in data]).to(device) n = len(data) - i, j = torch.meshgrid(torch.arange(n).to(device), torch.arange(n).to(device), indexing='ij') + i, j = torch.meshgrid( + torch.arange(n).to(device), torch.arange(n).to(device), indexing="ij" + ) t = torch.arange(30).to(device) - m1 = torch.where((i[:, :, None] - t >= 0) & (j[:, :, None] + t < n), mat[torch.clamp(i[:,:,None]-t, 0, n-1), torch.clamp(j[:,:,None]+t, 0, n-1)], 0) - m1 *= torch.exp(-0.5*t*t) + m1 = torch.where( + (i[:, :, None] - t >= 0) & (j[:, :, None] + t < n), + mat[ + 
torch.clamp(i[:, :, None] - t, 0, n - 1), + torch.clamp(j[:, :, None] + t, 0, n - 1), + ], + 0, + ) + m1 *= torch.exp(-0.5 * t * t) m1_0pad = torch.nn.functional.pad(m1, (0, 1)) - first0 = torch.argmax((m1_0pad==0).to(int), dim=2) - to0indices = t[None,None,:]>first0[:,:,None] + first0 = torch.argmax((m1_0pad == 0).to(int), dim=2) + to0indices = t[None, None, :] > first0[:, :, None] m1[to0indices] = 0 m1 = m1.sum(dim=2) t = torch.arange(1, 30).to(device) - m2 = torch.where((i[:, :, None] + t < n) & (j[:, :, None] - t >= 0), mat[torch.clamp(i[:,:,None]+t, 0, n-1), torch.clamp(j[:,:,None]-t, 0, n-1)], 0) - m2 *= torch.exp(-0.5*t*t) + m2 = torch.where( + (i[:, :, None] + t < n) & (j[:, :, None] - t >= 0), + mat[ + torch.clamp(i[:, :, None] + t, 0, n - 1), + torch.clamp(j[:, :, None] - t, 0, n - 1), + ], + 0, + ) + m2 *= torch.exp(-0.5 * t * t) m2_0pad = torch.nn.functional.pad(m2, (0, 1)) - first0 = torch.argmax((m2_0pad==0).to(int), dim=2) - to0indices = torch.arange(29).to(device)[None,None,:]>first0[:,:,None] + first0 = torch.argmax((m2_0pad == 0).to(int), dim=2) + to0indices = torch.arange(29).to(device)[None, None, :] > first0[:, :, None] m2[to0indices] = 0 m2 = m2.sum(dim=2) - m2[m1==0] = 0 + m2[m1 == 0] = 0 - return (m1+m2).to(self.device) + return (m1 + m2).to(self.device) # Assemble all data full_map = [] one_hot_embed = torch.zeros((5, 4), device=self.device) one_hot_embed[1:] = torch.eye(4) for seq in seq_int: - seq_hot = one_hot_embed[seq].type(torch.long) pair_map = torch.kron(seq_hot, seq_hot).reshape(len(seq), len(seq), 16) @@ -170,7 +169,6 @@ def creatmat(data, device=None): full_map.append(torch.cat((pair_map, energy_map.unsqueeze(-1)), dim=-1)) - return torch.stack(full_map).permute(0, 3, 1, 2).contiguous() @@ -213,9 +211,7 @@ def __init__( self.pos = PositionalEncoding(self.c_s, dropout) self.ln = nn.LayerNorm(self.c_s, eps=1e-12, elementwise_affine=True) - self.resNet = ResLayer( - dim_in=c_z, dim_out=c_z, n_blocks=2, kernel_size=3, dropout=dropout - ) + self.resNet = ResLayer(dim_in=c_z, dim_out=c_z, n_blocks=2, kernel_size=3, dropout=dropout) # self.tri_mul_out = TriangleMultiplicationOutgoing( # c_z, @@ -330,9 +326,7 @@ def forward(self, sequence_state, pairwise_state): # Update pairwise state pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state) - pairwise_state = self.resNet(pairwise_state.permute(0, 3, 1, 2)).permute( - 0, 2, 3, 1 - ) + pairwise_state = self.resNet(pairwise_state.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) # # Axial attention # pairwise_state = pairwise_state + self.row_drop( @@ -635,9 +629,7 @@ def __init__(self, n_blocks, dim_in, dim_out, kernel_size, dropout=0.0): self.res_blocks = nn.Sequential(*self.res_layers) # Adapter to change depth - self.conv_output = nn.Conv2d( - dim_in, dim_out, kernel_size=7, padding=3, bias=True - ) + self.conv_output = nn.Conv2d(dim_in, dim_out, kernel_size=7, padding=3, bias=True) def forward(self, x: Tensor) -> Tensor: x = self.res_blocks(x) @@ -662,14 +654,10 @@ def __init__( self.bn1 = nn.BatchNorm2d(inplanes) self.relu1 = nn.ReLU(inplace=True) - self.conv1 = conv3x3( - inplanes, planes, dilation=dilation1, kernel_size=kernel_size - ) + self.conv1 = conv3x3(inplanes, planes, dilation=dilation1, kernel_size=kernel_size) self.dropout = nn.Dropout(p=dropout) self.relu2 = nn.ReLU(inplace=True) - self.conv2 = conv3x3( - planes, planes, dilation=dilation2, kernel_size=kernel_size - ) + self.conv2 = conv3x3(planes, planes, dilation=dilation2, kernel_size=kernel_size) def forward(self, x: Tensor) -> Tensor: 
identity = x @@ -686,9 +674,7 @@ def forward(self, x: Tensor) -> Tensor: return out -def conv3x3( - in_planes: int, out_planes: int, dilation: int = 1, kernel_size=3 -) -> nn.Conv2d: +def conv3x3(in_planes: int, out_planes: int, dilation: int = 1, kernel_size=3) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d( in_planes, @@ -754,9 +740,7 @@ def __init__( self.value = nn.Linear(num_heads * head_size, num_heads * head_size, bias=False) ########################### - self.projection_kernel = nn.Parameter( - torch.rand(num_heads, head_size, output_size) * 2 - 1 - ) + self.projection_kernel = nn.Parameter(torch.rand(num_heads, head_size, output_size) * 2 - 1) if use_projection_bias: self.projection_bias = nn.Parameter(torch.rand(output_size) * 2 - 1) @@ -786,9 +770,7 @@ def call_qkv(self, query, key, value, training=False): return query, key, value - def call_attention( - self, query, key, value, logits, bias=None, training=False, mask=None - ): + def call_attention(self, query, key, value, logits, bias=None, training=False, mask=None): # Mask = attention mask with shape [B, Tquery, Tkey] with 1 for positions we want to attend, 0 for masked if mask is not None: if len(mask.size()) < 2: @@ -822,15 +804,11 @@ def call_attention( attn_coef_dropout = self.dropout(attn_coef) # Attention * value - multihead_output = torch.einsum( - "...HNM,...MHI->...NHI", attn_coef_dropout, value - ) + multihead_output = torch.einsum("...HNM,...MHI->...NHI", attn_coef_dropout, value) # Run the outputs through another linear projection layer. Recombining heads # is automatically done. - output = torch.einsum( - "...NHI,HIO->...NO", multihead_output, self.projection_kernel - ) + output = torch.einsum("...NHI,HIO->...NO", multihead_output, self.projection_kernel) if self.projection_bias is not None: output += self.projection_bias diff --git a/efold/models/factory.py b/efold/models/factory.py index ed60e99..cc62c7e 100644 --- a/efold/models/factory.py +++ b/efold/models/factory.py @@ -1,19 +1,15 @@ -from .transformer import Transformer -from .efold import eFold -from .cnn import CNN -from .ribonanza import Ribonanza -from .unet import U_Net +from efold.models import cnn, efold, ribonanza, transformer, unet def create_model(model: str, *args, **kwargs): if model == "transformer": - return Transformer(*args, **kwargs) + return transformer.Transformer(*args, **kwargs) if model == "efold": - return eFold(*args, **kwargs) + return efold.eFold(*args, **kwargs) if model == "cnn": - return CNN(*args, **kwargs) + return cnn.CNN(*args, **kwargs) if model == "unet": - return U_Net(*args, **kwargs) + return unet.U_Net(*args, **kwargs) if model == "ribonanza": - return Ribonanza(*args, **kwargs) + return ribonanza.Ribonanza(*args, **kwargs) raise ValueError(f"Unknown model: {model}") diff --git a/efold/models/ribonanza.py b/efold/models/ribonanza.py index 5909dec..bae485d 100644 --- a/efold/models/ribonanza.py +++ b/efold/models/ribonanza.py @@ -1,9 +1,10 @@ -from torch import nn, tensor import torch -from ..config import device, seq2int, START_TOKEN, END_TOKEN, PADDING_TOKEN -from ..core.model import Model +from torch import nn from torch.nn import init +from efold import settings +from efold.core import model + global_gain = 0.1 @@ -168,11 +169,11 @@ def sequence_batch(batch): out.append( torch.concat( [ - tensor([START_TOKEN], dtype=torch.long).to(device), + torch.tensor([settings.START_TOKEN], dtype=torch.long).to(settings.device), sequence[:length], - tensor([END_TOKEN], dtype=torch.long).to(device), - 
tensor([PADDING_TOKEN] * (L - length), dtype=torch.long).to( - device + torch.tensor([settings.END_TOKEN], dtype=torch.long).to(settings.device), + torch.tensor([settings.PADDING_TOKEN] * (L - length), dtype=torch.long).to( + settings.device ), ], ) @@ -182,9 +183,9 @@ def sequence_batch(batch): def structure_batch(batch): structure = batch.get("structure") batch_size, L, _ = structure.shape - embedded_matrix = torch.zeros( - (batch_size, L + 2, L + 2), dtype=torch.float32 - ).to(device) + embedded_matrix = torch.zeros((batch_size, L + 2, L + 2), dtype=torch.float32).to( + settings.device + ) embedded_matrix[:, 1:-1, 1:-1] = structure return embedded_matrix @@ -205,7 +206,7 @@ def forward(self, sequence, structure): return sequence, structure -class Ribonanza(Model): +class Ribonanza(model.Model): ntokens = 7 data_type = ["dms", "shape"] diff --git a/efold/models/transformer.py b/efold/models/transformer.py index d2601ae..9f762b6 100644 --- a/efold/models/transformer.py +++ b/efold/models/transformer.py @@ -1,17 +1,12 @@ +import numpy as np import torch from torch import nn, Tensor from torch.nn import TransformerEncoderLayer -import numpy as np -import os -import sys -from ..core.model import Model -from ..core.batch import Batch -dir_name = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(os.path.join(dir_name, "..")) +from efold.core import batch, model -class Transformer(Model): +class Transformer(model.Model): def __init__( self, ntoken: int, @@ -82,9 +77,7 @@ def __init__( assert c_z % 4 == 0, "c_z must be divisible by 4" assert c_z >= 8, "c_z must be greater than 8" self.output_net_structure = nn.Sequential( - ResLayer( - n_blocks=4, dim_in=c_z, dim_out=c_z // 2, kernel_size=3, dropout=dropout - ), + ResLayer(n_blocks=4, dim_in=c_z, dim_out=c_z // 2, kernel_size=3, dropout=dropout), ResLayer( n_blocks=4, dim_in=c_z // 2, @@ -92,12 +85,10 @@ def __init__( kernel_size=3, dropout=dropout, ), - ResLayer( - n_blocks=4, dim_in=c_z // 4, dim_out=1, kernel_size=3, dropout=dropout - ), + ResLayer(n_blocks=4, dim_in=c_z // 4, dim_out=1, kernel_size=3, dropout=dropout), ) - def forward(self, batch: Batch) -> Tensor: + def forward(self, batch: batch.Batch) -> Tensor: """ Args: src: Tensor, shape [seq_len, batch_size] @@ -120,14 +111,10 @@ def forward(self, batch: Batch) -> Tensor: # Outer concatenation src = self.activ(self.encoder_adapter(src)) matrix = src.unsqueeze(1).repeat(1, src.shape[1], 1, 1) # (N, d_cnn/2, L, L) - matrix = torch.cat( - (matrix, matrix.permute(0, 2, 1, 3)), dim=-1 - ) # (N, d_cnn, L, L) + matrix = torch.cat((matrix, matrix.permute(0, 2, 1, 3)), dim=-1) # (N, d_cnn, L, L) # Resnet layers - pair_prob = self.output_net_structure(matrix.permute(0, 3, 1, 2)).squeeze( - 1 - ) # (N, L, L) + pair_prob = self.output_net_structure(matrix.permute(0, 3, 1, 2)).squeeze(1) # (N, L, L) # Symmetrize structure = (pair_prob + pair_prob.permute(0, 2, 1)) / 2 # (N, L, L) @@ -183,9 +170,7 @@ def __init__(self, n_blocks, dim_in, dim_out, kernel_size, dropout=0.0): self.res_blocks = nn.Sequential(*self.res_layers) # Adapter to change depth - self.conv_output = nn.Conv2d( - dim_in, dim_out, kernel_size=7, padding=3, bias=True - ) + self.conv_output = nn.Conv2d(dim_in, dim_out, kernel_size=7, padding=3, bias=True) def forward(self, x: Tensor) -> Tensor: x = self.res_blocks(x) @@ -210,9 +195,7 @@ def __init__( self.bn1 = nn.BatchNorm2d(inplanes) self.relu1 = nn.ReLU(inplace=True) - self.conv1 = conv3x3( - inplanes, planes, dilation=dilation1, kernel_size=kernel_size - ) + self.conv1 = 
conv3x3(inplanes, planes, dilation=dilation1, kernel_size=kernel_size) self.dropout = nn.Dropout(p=dropout) self.relu2 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d( @@ -239,9 +222,7 @@ def forward(self, x: Tensor) -> Tensor: return out -def conv3x3( - in_planes: int, out_planes: int, dilation: int = 1, kernel_size=3 -) -> nn.Conv2d: +def conv3x3(in_planes: int, out_planes: int, dilation: int = 1, kernel_size=3) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d( in_planes, diff --git a/efold/models/unet.py b/efold/models/unet.py index 4f2bd8f..50f24e3 100644 --- a/efold/models/unet.py +++ b/efold/models/unet.py @@ -1,62 +1,55 @@ import torch from torch import nn, Tensor import torch.nn.functional as F -from torch.nn import init +from collections import defaultdict -from ..core.model import Model -from ..core.batch import Batch - -from ..config import int2seq - -import os, sys - -from collections import defaultdict - -dir_name = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(os.path.join(dir_name, "..")) +from efold import settings +from efold.core import batch, model CH_FOLD2 = 1 -class U_Net(Model): - def __init__(self, img_ch=3, output_ch=1, lr: float = 1e-5, optimizer_fn=torch.optim.Adam, **kwargs): - +class U_Net(model.Model): + def __init__( + self, img_ch=3, output_ch=1, lr: float = 1e-5, optimizer_fn=torch.optim.Adam, **kwargs + ): super().__init__(lr=lr, optimizer_fn=optimizer_fn, **kwargs) self.model_type = "Unet" self.data_type_output = ["structure"] - - self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) - self.Conv1 = conv_block(ch_in=img_ch,ch_out=int(32*CH_FOLD2)) - self.Conv2 = conv_block(ch_in=int(32*CH_FOLD2),ch_out=int(64*CH_FOLD2)) - self.Conv3 = conv_block(ch_in=int(64*CH_FOLD2),ch_out=int(128*CH_FOLD2)) - self.Conv4 = conv_block(ch_in=int(128*CH_FOLD2),ch_out=int(256*CH_FOLD2)) - self.Conv5 = conv_block(ch_in=int(256*CH_FOLD2),ch_out=int(512*CH_FOLD2)) + self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) + + self.Conv1 = conv_block(ch_in=img_ch, ch_out=int(32 * CH_FOLD2)) + self.Conv2 = conv_block(ch_in=int(32 * CH_FOLD2), ch_out=int(64 * CH_FOLD2)) + self.Conv3 = conv_block(ch_in=int(64 * CH_FOLD2), ch_out=int(128 * CH_FOLD2)) + self.Conv4 = conv_block(ch_in=int(128 * CH_FOLD2), ch_out=int(256 * CH_FOLD2)) + self.Conv5 = conv_block(ch_in=int(256 * CH_FOLD2), ch_out=int(512 * CH_FOLD2)) - self.Up5 = up_conv(ch_in=int(512*CH_FOLD2),ch_out=int(256*CH_FOLD2)) - self.Up_conv5 = conv_block(ch_in=int(512*CH_FOLD2), ch_out=int(256*CH_FOLD2)) + self.Up5 = up_conv(ch_in=int(512 * CH_FOLD2), ch_out=int(256 * CH_FOLD2)) + self.Up_conv5 = conv_block(ch_in=int(512 * CH_FOLD2), ch_out=int(256 * CH_FOLD2)) - self.Up4 = up_conv(ch_in=int(256*CH_FOLD2),ch_out=int(128*CH_FOLD2)) - self.Up_conv4 = conv_block(ch_in=int(256*CH_FOLD2), ch_out=int(128*CH_FOLD2)) - - self.Up3 = up_conv(ch_in=int(128*CH_FOLD2),ch_out=int(64*CH_FOLD2)) - self.Up_conv3 = conv_block(ch_in=int(128*CH_FOLD2), ch_out=int(64*CH_FOLD2)) - - self.Up2 = up_conv(ch_in=int(64*CH_FOLD2),ch_out=int(32*CH_FOLD2)) - self.Up_conv2 = conv_block(ch_in=int(64*CH_FOLD2), ch_out=int(32*CH_FOLD2)) + self.Up4 = up_conv(ch_in=int(256 * CH_FOLD2), ch_out=int(128 * CH_FOLD2)) + self.Up_conv4 = conv_block(ch_in=int(256 * CH_FOLD2), ch_out=int(128 * CH_FOLD2)) - self.Conv_1x1 = nn.Conv2d(int(32*CH_FOLD2),output_ch,kernel_size=1,stride=1,padding=0) + self.Up3 = up_conv(ch_in=int(128 * CH_FOLD2), ch_out=int(64 * CH_FOLD2)) + self.Up_conv3 = conv_block(ch_in=int(128 * CH_FOLD2), ch_out=int(64 * CH_FOLD2)) + 
self.Up2 = up_conv(ch_in=int(64 * CH_FOLD2), ch_out=int(32 * CH_FOLD2)) + self.Up_conv2 = conv_block(ch_in=int(64 * CH_FOLD2), ch_out=int(32 * CH_FOLD2)) - def forward(self, batch: Batch) -> Tensor: + self.Conv_1x1 = nn.Conv2d(int(32 * CH_FOLD2), output_ch, kernel_size=1, stride=1, padding=0) + def forward(self, batch: batch.Batch) -> Tensor: src = batch.get("sequence") padd_multiple = 32 pad_len = 0 - if src.shape[1]%padd_multiple: - pad_len = (padd_multiple-src.shape[1]%padd_multiple) - src = torch.cat( (src, torch.zeros((src.shape[0], pad_len), device=self.device, dtype=torch.long) ), dim=-1) + if src.shape[1] % padd_multiple: + pad_len = padd_multiple - src.shape[1] % padd_multiple + src = torch.cat( + (src, torch.zeros((src.shape[0], pad_len), device=self.device, dtype=torch.long)), + dim=-1, + ) # def get_cut_len(data_len,set_len): # l = data_len @@ -78,7 +71,7 @@ def forward(self, batch: Batch) -> Tensor: x2 = self.Maxpool(x1) x2 = self.Conv2(x2) - + x3 = self.Maxpool(x2) x3 = self.Conv3(x3) @@ -90,20 +83,20 @@ def forward(self, batch: Batch) -> Tensor: # decoding + concat path d5 = self.Up5(x5) - d5 = torch.cat((x4,d5),dim=1) - + d5 = torch.cat((x4, d5), dim=1) + d5 = self.Up_conv5(d5) - + d4 = self.Up4(d5) - d4 = torch.cat((x3,d4),dim=1) + d4 = torch.cat((x3, d4), dim=1) d4 = self.Up_conv4(d4) d3 = self.Up3(d4) - d3 = torch.cat((x2,d3),dim=1) + d3 = torch.cat((x2, d3), dim=1) d3 = self.Up_conv3(d3) d2 = self.Up2(d3) - d2 = torch.cat((x1,d2),dim=1) + d2 = torch.cat((x1, d2), dim=1) d2 = self.Up_conv2(d2) d1 = self.Conv_1x1(d2) @@ -111,53 +104,65 @@ def forward(self, batch: Batch) -> Tensor: structure = torch.transpose(d1, -1, -2) * d1 - return { - "structure": structure[:, :src.shape[1]-pad_len, :src.shape[1]-pad_len] - } - + return {"structure": structure[:, : src.shape[1] - pad_len, : src.shape[1] - pad_len]} def seq2map(self, seq_int): - # take integer encoded sequence and return last channel of embedding (pairing energy) def creatmat(data, device=None): - with torch.no_grad(): - data = ''.join([int2seq[d] for d in data.tolist()]) - paired = defaultdict(float, {'AU':2., 'UA':2., 'GC':3., 'CG':3., 'UG':0.8, 'GU':0.8}) + data = "".join([settings.int2seq[d] for d in data.tolist()]) + paired = defaultdict( + float, {"AU": 2.0, "UA": 2.0, "GC": 3.0, "CG": 3.0, "UG": 0.8, "GU": 0.8} + ) - mat = torch.tensor([[paired[x+y] for y in data] for x in data]).to(device) + mat = torch.tensor([[paired[x + y] for y in data] for x in data]).to(device) n = len(data) - i, j = torch.meshgrid(torch.arange(n).to(device), torch.arange(n).to(device), indexing='ij') + i, j = torch.meshgrid( + torch.arange(n).to(device), torch.arange(n).to(device), indexing="ij" + ) t = torch.arange(30).to(device) - m1 = torch.where((i[:, :, None] - t >= 0) & (j[:, :, None] + t < n), mat[torch.clamp(i[:,:,None]-t, 0, n-1), torch.clamp(j[:,:,None]+t, 0, n-1)], 0) - m1 *= torch.exp(-0.5*t*t) + m1 = torch.where( + (i[:, :, None] - t >= 0) & (j[:, :, None] + t < n), + mat[ + torch.clamp(i[:, :, None] - t, 0, n - 1), + torch.clamp(j[:, :, None] + t, 0, n - 1), + ], + 0, + ) + m1 *= torch.exp(-0.5 * t * t) m1_0pad = torch.nn.functional.pad(m1, (0, 1)) - first0 = torch.argmax((m1_0pad==0).to(int), dim=2) - to0indices = t[None,None,:]>first0[:,:,None] + first0 = torch.argmax((m1_0pad == 0).to(int), dim=2) + to0indices = t[None, None, :] > first0[:, :, None] m1[to0indices] = 0 m1 = m1.sum(dim=2) t = torch.arange(1, 30).to(device) - m2 = torch.where((i[:, :, None] + t < n) & (j[:, :, None] - t >= 0), mat[torch.clamp(i[:,:,None]+t, 
0, n-1), torch.clamp(j[:,:,None]-t, 0, n-1)], 0) - m2 *= torch.exp(-0.5*t*t) + m2 = torch.where( + (i[:, :, None] + t < n) & (j[:, :, None] - t >= 0), + mat[ + torch.clamp(i[:, :, None] + t, 0, n - 1), + torch.clamp(j[:, :, None] - t, 0, n - 1), + ], + 0, + ) + m2 *= torch.exp(-0.5 * t * t) m2_0pad = torch.nn.functional.pad(m2, (0, 1)) - first0 = torch.argmax((m2_0pad==0).to(int), dim=2) - to0indices = torch.arange(29).to(device)[None,None,:]>first0[:,:,None] + first0 = torch.argmax((m2_0pad == 0).to(int), dim=2) + to0indices = torch.arange(29).to(device)[None, None, :] > first0[:, :, None] m2[to0indices] = 0 m2 = m2.sum(dim=2) - m2[m1==0] = 0 + m2[m1 == 0] = 0 - return (m1+m2).to(self.device) + return (m1 + m2).to(self.device) # Assemble all data full_map = [] one_hot_embed = torch.zeros((5, 4), device=self.device) one_hot_embed[1:] = torch.eye(4) for seq in seq_int: - seq_hot = one_hot_embed[seq].type(torch.long) pair_map = torch.kron(seq_hot, seq_hot).reshape(len(seq), len(seq), 16) @@ -165,46 +170,41 @@ def creatmat(data, device=None): full_map.append(torch.cat((pair_map, energy_map.unsqueeze(-1)), dim=-1)) - return torch.stack(full_map).permute(0, 3, 1, 2).contiguous() - class conv_block(nn.Module): - def __init__(self,ch_in,ch_out): - super(conv_block,self).__init__() + def __init__(self, ch_in, ch_out): + super(conv_block, self).__init__() self.conv = nn.Sequential( - nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), + nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True), - nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True), + nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), nn.BatchNorm2d(ch_out), - nn.ReLU(inplace=True) + nn.ReLU(inplace=True), ) - - def forward(self,x): + def forward(self, x): x = self.conv(x) return x + class up_conv(nn.Module): - def __init__(self,ch_in,ch_out): - super(up_conv,self).__init__() + def __init__(self, ch_in, ch_out): + super(up_conv, self).__init__() self.up = nn.Sequential( nn.Upsample(scale_factor=2), - nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True), - nn.BatchNorm2d(ch_out), - nn.ReLU(inplace=True) + nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), + nn.BatchNorm2d(ch_out), + nn.ReLU(inplace=True), ) - def forward(self,x): + def forward(self, x): x = self.up(x) return x - - - # class U_Net_FP(nn.Module): # def __init__(self,img_ch=17,output_ch=1): # super(U_Net_FP, self).__init__() @@ -277,4 +277,4 @@ def forward(self,x): # d1 = self.Conv_1x1(d2) # d1 = d1.squeeze(1) -# return torch.transpose(d1, -1, -2) * d1 \ No newline at end of file +# return torch.transpose(d1, -1, -2) * d1 diff --git a/efold/settings.py b/efold/settings.py new file mode 100644 index 0000000..5b2837d --- /dev/null +++ b/efold/settings.py @@ -0,0 +1,56 @@ +from pathlib import Path +from typing import Any + +import torch +import yaml +from torch import backends, cuda, float32 + +_settings_path = Path(__file__).parent / "settings.yaml" + + +def _load_settings() -> dict[str, Any]: + """ + Load settings from the YAML configuration file. 
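+
+    Expected top-level keys (see settings.yaml): seq2int, start_token,
+    end_token, padding_token_key, default_format, unknown_value, val_gu,
+    test_sets, data_types, data_types_format, reference_metric,
+    ref_metric_sign, and possible_metrics.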
+ + :return: Dictionary containing all settings + """ + with open(_settings_path, "r") as f: + return yaml.safe_load(f) + + +_config = _load_settings() + +seq2int = _config["seq2int"] +int2seq = {v: k for k, v in seq2int.items()} + +START_TOKEN = _config["start_token"] +END_TOKEN = _config["end_token"] +PADDING_TOKEN = seq2int[_config["padding_token_key"]] + +DEFAULT_FORMAT = float32 +torch.set_default_dtype(DEFAULT_FORMAT) +UKN = _config["unknown_value"] +VAL_GU = _config["val_gu"] + +device = "cuda" if cuda.is_available() else "cpu" + +TEST_SETS = _config["test_sets"] +TEST_SETS_NAMES = [i for j in TEST_SETS.values() for i in j] +DATA_TYPES_TEST_SETS = [k for k, v in TEST_SETS.items() for i in v] + +DATA_TYPES = _config["data_types"] + +_dtype_mapping = { + "float32": torch.float32, + "int32": torch.int32, +} + +DATA_TYPES_FORMAT = {k: _dtype_mapping[v] for k, v in _config["data_types_format"].items()} + +REFERENCE_METRIC = _config["reference_metric"] +REF_METRIC_SIGN = _config["ref_metric_sign"] +POSSIBLE_METRICS = _config["possible_metrics"] + +DTYPE_PER_DATA_TYPE = DATA_TYPES_FORMAT + +torch.set_default_dtype(torch.float32) diff --git a/efold/settings.yaml b/efold/settings.yaml new file mode 100644 index 0000000..07436e5 --- /dev/null +++ b/efold/settings.yaml @@ -0,0 +1,66 @@ +# Token mappings +seq2int: + X: 0 + A: 1 + C: 2 + G: 3 + U: 4 + +# Special tokens +start_token: null +end_token: null +padding_token_key: X + +# PyTorch configuration +default_format: float32 +unknown_value: -1000.0 +val_gu: 0.095 + +# Test sets +test_sets: + structure: + - PDB + - archiveII + - lncRNA_nonFiltered + - viral_fragments + sequence: [] + dms: [] + shape: [] + +# Data types +data_types: + - structure + - dms + - shape + +# Data type formats +data_types_format: + structure: int32 + dms: float32 + shape: float32 + +# Reference metrics +reference_metric: + structure: f1 + dms: mae + shape: mae + +# Metric signs (1 for higher is better, -1 for lower is better) +ref_metric_sign: + structure: 1 + dms: -1 + shape: -1 + +# Possible metrics per data type +possible_metrics: + structure: + - f1 + dms: + - mae + - r2 + - pearson + shape: + - mae + - r2 + - pearson + diff --git a/efold/util/__init__.py b/efold/util/__init__.py index f64f073..e69de29 100644 --- a/efold/util/__init__.py +++ b/efold/util/__init__.py @@ -1,21 +0,0 @@ -from torch.optim import Adam -from torch import nn - -str2fun = { - "adam": Adam, - "Adam": Adam, - "mse_loss": nn.functional.mse_loss, - "l1_loss": nn.functional.l1_loss, -} - - -def unzip(f): - def wrapper(*args, **kwargs): - out = f(*args, **kwargs) - if not hasattr(out, "__iter__"): - return out - if len(out) == 1: - return out[0] - return out - - return wrapper diff --git a/efold/util/format_conversion.py b/efold/util/format_conversion.py index f1f46de..27d2c83 100644 --- a/efold/util/format_conversion.py +++ b/efold/util/format_conversion.py @@ -1,11 +1,9 @@ # this code wsa taken from arnie_utils.py -def convert_bp_list_to_dotbracket(bp_list, seq_len): - - # convert the bp list from 1-indexed to 0-indexed - bp_list = [(b-1, c-1) for b, c in bp_list] - +def convert_bp_list_to_dotbracket(bp_list: list[tuple[int, int]], seq_len: int) -> str: + bp_list = [(b - 1, c - 1) for b, c in bp_list] + db = "." * seq_len # group into bps that are not intertwined and can use same brackets! 
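    # e.g. bp_list = [(0, 20), (5, 30)]: the two pairs are intertwined
    # (0 < 5 < 20 < 30, a pseudoknot), so they cannot share one bracket
    # type and the second pair lands in a later group drawn with [].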
groups = _group_into_non_conflicting_bp(bp_list)

@@ -13,8 +11,9 @@ def convert_bp_list_to_dotbracket(bp_list, seq_len):
    # all bp that are not intertwined get (), but all others are
    # grouped to be non-conflicting and then assigned (), [], {}, <> by group
    chars_set = [("(", ")"), ("(", ")"), ("[", "]"), ("{", "}"), ("<", ">")]
-    alphabet = [(chr(lower), chr(upper))
-                for upper, lower in zip(list(range(65, 91)), list(range(97, 123)))]
+    alphabet = [
+        (chr(lower), chr(upper)) for upper, lower in zip(list(range(65, 91)), list(range(97, 123)))
+    ]
    chars_set.extend(alphabet)

    if len(groups) > len(chars_set):
@@ -23,25 +22,23 @@
    for group, chars in zip(groups, chars_set):
        for bp in group:
-            db = db[:bp[0]] + chars[0] + \
-                db[bp[0] + 1:bp[1]] + chars[1] + db[bp[1] + 1:]
+            db = db[: bp[0]] + chars[0] + db[bp[0] + 1 : bp[1]] + chars[1] + db[bp[1] + 1 :]

    return db

-def _group_into_non_conflicting_bp(bp_list):
-    ''' given a bp_list, group basepairs into groups that do not conflict
+def _group_into_non_conflicting_bp(bp_list: list[tuple[int, int]]) -> list[list[tuple[int, int]]]:
+    """given a bp_list, group basepairs into groups that do not conflict

    Args:
        bp_list: list of base_pairs

    Returns:
        groups of basepairs that are not intertwined
-    '''
+    """
    conflict_list = _get_list_bp_conflicts(bp_list)
    non_redudant_bp_list = _get_non_redudant_bp_list(conflict_list)
-    bp_with_no_conflict = [
-        bp for bp in bp_list if bp not in non_redudant_bp_list]
+    bp_with_no_conflict = [bp for bp in bp_list if bp not in non_redudant_bp_list]
    groups = [bp_with_no_conflict]
    while non_redudant_bp_list != []:
        current_bp = non_redudant_bp_list[0]
@@ -51,8 +48,7 @@
                current_bp_conflicts.append(conflict[1])
            elif current_bp == conflict[1]:
                current_bp_conflicts.append(conflict[0])
-        max_group = [
-            bp for bp in non_redudant_bp_list if bp not in current_bp_conflicts]
+        max_group = [bp for bp in non_redudant_bp_list if bp not in current_bp_conflicts]
        to_remove = []
        for i, bpA in enumerate(max_group):
            for bpB in max_group[i:]:
@@ -62,19 +58,22 @@
        group = [bp for bp in max_group if bp not in to_remove]
        groups.append(group)
        non_redudant_bp_list = current_bp_conflicts
-        conflict_list = [conflict for conflict in conflict_list if conflict[0]
-                         not in group and conflict[1] not in group]
+        conflict_list = [
+            conflict
+            for conflict in conflict_list
+            if conflict[0] not in group and conflict[1] not in group
+        ]
    return groups

-def _get_non_redudant_bp_list(conflict_list):
-    ''' given a conflict list get the list of nonredundant basepairs this list has
+def _get_non_redudant_bp_list(conflict_list: list) -> list:
+    """given a conflict list, get the list of nonredundant basepairs; this list has no repeats

    Args:
        conflict_list: list of pairs of base_pairs that are intertwined basepairs
    returns:
        list of basepairs in conflict list without repeats
-    '''
+    """
    non_redudant_bp_list = []
    for conflict in conflict_list:
        if conflict[0] not in non_redudant_bp_list:
@@ -84,20 +83,19 @@
    return non_redudant_bp_list

-
-def _get_list_bp_conflicts(bp_list):
-    '''given a bp_list gives the list of conflicts bp-s which indicate PK structure
+def _get_list_bp_conflicts(bp_list: list[tuple[int, int]]) -> list:
+    """given a bp_list, gives the list of conflicting bps, which indicate PK structure
    Args:
        bp_list: list of base pairs, where each base pair is a list of indices of the bp in increasing order
(bp[0] threshold).float()\n", "\n", - "\n", - " TP = torch.sum(pred_matrix*target_matrix)\n", + " TP = torch.sum(pred_matrix * target_matrix)\n", " PP = torch.sum(pred_matrix)\n", " P = torch.sum(target_matrix)\n", " sum_pair = PP + P\n", @@ -102,11 +104,8 @@ " if sum_pair == 0:\n", " return [1.0, 1.0, 1.0]\n", " else:\n", - " return [\n", - " (TP / PP).item(),\n", - " (TP / P).item(),\n", - " (2 * TP / sum_pair).item()\n", - " ]\n", + " return [(TP / PP).item(), (TP / P).item(), (2 * TP / sum_pair).item()]\n", + "\n", "\n", "def compute_confusion_matrix(label, pred):\n", " true_negatives = (1 - label) * (1 - pred)\n", @@ -114,7 +113,9 @@ " false_positives = (1 - label) * pred\n", " false_negatives = label * (1 - pred)\n", " confusion_matrix = true_positives + false_positives * 2 + false_negatives * 3\n", - " assert ((true_negatives == 1) == (confusion_matrix == 0)).all(), \"True negatives are not correctly computed\"\n", + " assert ((true_negatives == 1) == (confusion_matrix == 0)).all(), (\n", + " \"True negatives are not correctly computed\"\n", + " )\n", " return confusion_matrix" ] }, @@ -124,10 +125,14 @@ "metadata": {}, "outputs": [], "source": [ - "ground_truth['non_canonical'] = ground_truth.apply(lambda x: ratio_nonCanonical(x['sequence'], x['structure']), axis=1)\n", - "ground_truth['sharp_loops'] = ground_truth.apply(lambda x: ratio_sharpLoops(x['structure']), axis=1)\n", - "ground_truth[ground_truth['sharp_loops'] >0]\n", - "ground_truth['pairing_matrix'] = ground_truth.apply(lambda x: ListofPairs2pairMatrix(np.array(x['structure']), len(x['sequence'])), axis=1)" + "ground_truth[\"non_canonical\"] = ground_truth.apply(\n", + " lambda x: ratio_nonCanonical(x[\"sequence\"], x[\"structure\"]), axis=1\n", + ")\n", + "ground_truth[\"sharp_loops\"] = ground_truth.apply(lambda x: ratio_sharpLoops(x[\"structure\"]), axis=1)\n", + "ground_truth[ground_truth[\"sharp_loops\"] > 0]\n", + "ground_truth[\"pairing_matrix\"] = ground_truth.apply(\n", + " lambda x: ListofPairs2pairMatrix(np.array(x[\"structure\"]), len(x[\"sequence\"])), axis=1\n", + ")" ] }, { @@ -138,7 +143,7 @@ "source": [ "import time\n", "from tqdm import tqdm\n", - "from efold import inference\n", + "from efold.api.run import run\n", "\n", "eFold_processed = pd.DataFrame()\n", "\n", @@ -154,15 +159,17 @@ " dTs = []\n", "\n", " for idx, row in tqdm(ground_truth.iterrows(), total=len(ground_truth)):\n", - " true_structure = torch.tensor(row['structure'])\n", - " sequence = row['sequence']\n", + " true_structure = torch.tensor(row[\"structure\"])\n", + " sequence = row[\"sequence\"]\n", "\n", " t0 = time.time()\n", - " prediction = torch.tensor(inference(sequence, fmt='bp')[sequence])-1\n", + " prediction = torch.tensor(run.run(sequence, fmt=\"bp\")[sequence]) - 1\n", " dT = time.time() - t0\n", "\n", - " precision, recall, f1 = compute_f1(ListofPairs2pairMatrix(prediction, len(sequence)), \n", - " ListofPairs2pairMatrix(true_structure, len(sequence)))\n", + " precision, recall, f1 = compute_f1(\n", + " ListofPairs2pairMatrix(prediction, len(sequence)),\n", + " ListofPairs2pairMatrix(true_structure, len(sequence)),\n", + " )\n", "\n", " Precisions.append(precision)\n", " Recalls.append(recall)\n", @@ -170,12 +177,31 @@ " predictions.append(prediction)\n", " dTs.append(dT)\n", "\n", - " eFold_processed = pd.concat([eFold_processed, pd.DataFrame({'reference': ground_truth.index, 'sequence': ground_truth['sequence'],\n", - " 'threshold': threshold,\n", - " 'precision': Precisions, 'recall': Recalls, 'f1': F1s, 'dT': dTs,\n", - 
" 'structure': predictions})], axis=0)\n", + " eFold_processed = pd.concat(\n", + " [\n", + " eFold_processed,\n", + " pd.DataFrame(\n", + " {\n", + " \"reference\": ground_truth.index,\n", + " \"sequence\": ground_truth[\"sequence\"],\n", + " \"threshold\": threshold,\n", + " \"precision\": Precisions,\n", + " \"recall\": Recalls,\n", + " \"f1\": F1s,\n", + " \"dT\": dTs,\n", + " \"structure\": predictions,\n", + " }\n", + " ),\n", + " ],\n", + " axis=0,\n", + " )\n", "# Add dataset name\n", - "eFold_processed = eFold_processed.merge(ground_truth.reset_index().rename(columns={'index':'reference'})[['reference', 'sequence', 'dataset']], on=['reference', 'sequence'])\n", + "eFold_processed = eFold_processed.merge(\n", + " ground_truth.reset_index().rename(columns={\"index\": \"reference\"})[\n", + " [\"reference\", \"sequence\", \"dataset\"]\n", + " ],\n", + " on=[\"reference\", \"sequence\"],\n", + ")\n", "# eFold_processed['basePairs'] = eFold_processed['structure'].apply(lambda x: torch.unique(torch.sort(torch.stack(torch.where(x>0)).T, dim=1)[0], dim=0))" ] }, @@ -185,13 +211,21 @@ "metadata": {}, "outputs": [], "source": [ - "eFold_processed['non_canonical'] = eFold_processed.apply(lambda x: ratio_nonCanonical(x['sequence'], x['structure']), axis=1)\n", - "eFold_processed['sharp_loops'] = eFold_processed.apply(lambda x: ratio_sharpLoops(x['structure']), axis=1)\n", - "eFold_processed['pairing_matrix'] = eFold_processed.apply(lambda x: ListofPairs2pairMatrix(x['structure'], len(x['sequence'])), axis=1)\n", - "eFold_processed['multiPairs'] = eFold_processed['pairing_matrix'].apply(lambda x: (x.sum(axis=0) >1).sum().item()/len(x) )\n", - "eFold_processed['length'] = eFold_processed['sequence'].apply(len)\n", - "\n", - "eFold_processed.groupby('threshold')[['non_canonical', 'sharp_loops', 'multiPairs']].mean()" + "eFold_processed[\"non_canonical\"] = eFold_processed.apply(\n", + " lambda x: ratio_nonCanonical(x[\"sequence\"], x[\"structure\"]), axis=1\n", + ")\n", + "eFold_processed[\"sharp_loops\"] = eFold_processed.apply(\n", + " lambda x: ratio_sharpLoops(x[\"structure\"]), axis=1\n", + ")\n", + "eFold_processed[\"pairing_matrix\"] = eFold_processed.apply(\n", + " lambda x: ListofPairs2pairMatrix(x[\"structure\"], len(x[\"sequence\"])), axis=1\n", + ")\n", + "eFold_processed[\"multiPairs\"] = eFold_processed[\"pairing_matrix\"].apply(\n", + " lambda x: (x.sum(axis=0) > 1).sum().item() / len(x)\n", + ")\n", + "eFold_processed[\"length\"] = eFold_processed[\"sequence\"].apply(len)\n", + "\n", + "eFold_processed.groupby(\"threshold\")[[\"non_canonical\", \"sharp_loops\", \"multiPairs\"]].mean()" ] }, { @@ -201,10 +235,12 @@ "outputs": [], "source": [ "# Group the data by model and dataset and calculate the mean for each group\n", - "grouped = eFold_processed.groupby(['threshold', 'dataset']).mean(numeric_only=True).reset_index()\n", + "grouped = eFold_processed.groupby([\"threshold\", \"dataset\"]).mean(numeric_only=True).reset_index()\n", "\n", "# Pivot the table to create a multi-level column structure\n", - "pivot_df = pd.pivot_table(grouped, index='threshold', columns='dataset', values=['precision', 'recall', 'f1'])\n", + "pivot_df = pd.pivot_table(\n", + " grouped, index=\"threshold\", columns=\"dataset\", values=[\"precision\", \"recall\", \"f1\"]\n", + ")\n", "\n", "# Swap the level of the columns to have dataset as the top level and the metrics as the second level\n", "pivot_df = pivot_df.swaplevel(i=0, j=1, axis=1).sort_index(axis=1)\n", @@ -213,16 +249,21 @@ "# new_order = 
['SimpleThreshold', 'HungarianAlgorithm', 'UFold_processing', 'OptimalProcessing']\n", "# pivot_df = pivot_df.reindex(new_order)\n", "\n", - "pivot_df = pivot_df.reindex(columns=pivot_df.columns.reindex(['precision', 'recall', 'f1'], level=1)[0])[['PDB', 'archiveII_blast', 'viral_fragments', 'lncRNA_nonFiltered']]\n", - "\n", - "pivot_df = pivot_df.style\\\n", - " .format(precision=3)\\\n", - " .highlight_max(axis=0, props=\"font-weight:bold;font-color:black;\")\\\n", - " .background_gradient(axis=1, vmin=-0.1, vmax=1, cmap=\"viridis\", text_color_threshold=0)\\\n", - " .set_properties(**{'text-align': 'center'})\\\n", - " .set_table_styles(\n", - " [{\"selector\": \"th\", \"props\": [('text-align', 'center')]},\n", - " ])\n", + "pivot_df = pivot_df.reindex(\n", + " columns=pivot_df.columns.reindex([\"precision\", \"recall\", \"f1\"], level=1)[0]\n", + ")[[\"PDB\", \"archiveII_blast\", \"viral_fragments\", \"lncRNA_nonFiltered\"]]\n", + "\n", + "pivot_df = (\n", + " pivot_df.style.format(precision=3)\n", + " .highlight_max(axis=0, props=\"font-weight:bold;font-color:black;\")\n", + " .background_gradient(axis=1, vmin=-0.1, vmax=1, cmap=\"viridis\", text_color_threshold=0)\n", + " .set_properties(**{\"text-align\": \"center\"})\n", + " .set_table_styles(\n", + " [\n", + " {\"selector\": \"th\", \"props\": [(\"text-align\", \"center\")]},\n", + " ]\n", + " )\n", + ")\n", "pivot_df" ] } diff --git a/tests/test_speed_eFold.py b/tests/test_speed_eFold.py index 6e9d9f2..4521c6e 100644 --- a/tests/test_speed_eFold.py +++ b/tests/test_speed_eFold.py @@ -1,8 +1,9 @@ import sys, os + file_dir = os.path.dirname(os.path.realpath(__file__)) -sys.path.append(os.path.join(file_dir, '../../eFold')) +sys.path.append(os.path.join(file_dir, "../../eFold")) -from efold import inference +from efold.api import run as run_module import pandas as pd import numpy as np from rouskinhf import get_dataset @@ -13,7 +14,7 @@ from rnastructure_wrapper import RNAstructure -Fold = RNAstructure(path='/root/RNAstructure/exe/') +Fold = RNAstructure(path="/root/RNAstructure/exe/") rnaStructure_dTs = [] efold_GPU_dTs = [] @@ -21,38 +22,42 @@ lengths = np.linspace(10, 500, 100).astype(int) for length in tqdm(lengths): - - sequence_random = ''.join(np.random.choice(['A', 'C', 'G', 'U'], length)) + sequence_random = "".join(np.random.choice(["A", "C", "G", "U"], length)) # eFold GPU t0 = time.time() - inference(sequence_random, fmt='bp') - dT = time.time()-t0 + run_module.run(sequence_random, fmt="bp") + dT = time.time() - t0 efold_GPU_dTs.append(dT) # RNAfold t0 = time.time() - inference(sequence_random, fmt='bp', device='cpu') - dT = time.time()-t0 + run_module.run(sequence_random, fmt="bp", device="cpu") + dT = time.time() - t0 efold_CPU_dTs.append(dT) # RNAstructure no MFE t0 = time.time() Fold.fold(sequence_random, mfe_only=False) - dT = time.time()-t0 + dT = time.time() - t0 rnaStructure_dTs.append(dT) import plotly.graph_objects as go fig = go.Figure() -fig.add_trace(go.Scatter(x=lengths, y=rnaStructure_dTs, mode='lines', name='RNAstructure')) -fig.add_trace(go.Scatter(x=lengths, y=efold_CPU_dTs, mode='lines', name='eFold (CPU)')) -fig.add_trace(go.Scatter(x=lengths, y=efold_GPU_dTs, mode='lines', name='eFold (GPU)')) -fig.update_layout(title='Inference time of eFold vs RNAstructure', - xaxis_title='Sequence length', yaxis_title='Time [s]', - template='plotly_white', font_size=22, margin=dict(l=100, r=20, t=100, b=100), - width=2000, height=1200) +fig.add_trace(go.Scatter(x=lengths, y=rnaStructure_dTs, mode="lines", 
name="RNAstructure")) +fig.add_trace(go.Scatter(x=lengths, y=efold_CPU_dTs, mode="lines", name="eFold (CPU)")) +fig.add_trace(go.Scatter(x=lengths, y=efold_GPU_dTs, mode="lines", name="eFold (GPU)")) +fig.update_layout( + title="Inference time of eFold vs RNAstructure", + xaxis_title="Sequence length", + yaxis_title="Time [s]", + template="plotly_white", + font_size=22, + margin=dict(l=100, r=20, t=100, b=100), + width=2000, + height=1200, +) # fig.show() -fig.write_image(os.path.join(file_dir, 'speed_comparison.jpg')) - +fig.write_image(os.path.join(file_dir, "speed_comparison.jpg")) From f96a971c14a7d4478eb7117904194cd5327fd38e Mon Sep 17 00:00:00 2001 From: taillades Date: Sun, 19 Oct 2025 14:08:34 -0700 Subject: [PATCH 2/4] ruff passes --- README.md | 132 ++++++++++++++++++++++++++++++++ efold/api/run.py | 11 +-- efold/cli.py | 1 + efold/core/batch.py | 10 +-- efold/core/callbacks.py | 7 +- efold/core/datamodule.py | 15 ++-- efold/core/dataset.py | 11 ++- efold/core/datatype.py | 1 - efold/core/embeddings.py | 2 +- efold/core/loader.py | 5 +- efold/core/logger.py | 3 +- efold/core/metrics.py | 2 +- efold/core/model.py | 9 ++- efold/core/path.py | 6 +- efold/core/postprocess.py | 10 +-- efold/core/sampler.py | 14 ++-- efold/core/visualisation.py | 2 +- efold/models/cnn.py | 4 +- efold/models/efold.py | 22 +++--- efold/models/ribonanza.py | 3 - efold/models/transformer.py | 6 +- efold/models/unet.py | 6 +- efold/settings.py | 2 +- scripts/cnn_template.py | 15 ++-- scripts/efold_training.py | 5 +- scripts/mlp-template.py | 12 ++- scripts/ribonanza-template.py | 7 +- scripts/transformer-template.py | 12 +-- scripts/unet_training.py | 9 +-- setup.py | 2 +- tests/test_eFold.ipynb | 10 +-- tests/test_speed_eFold.py | 18 ++--- 32 files changed, 248 insertions(+), 126 deletions(-) diff --git a/README.md b/README.md index e69de29..7745013 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,132 @@ +# eFold + +This repo contains the pytorch code for our paper “*Diverse Database and Machine Learning Model to narrow the generalization gap in RNA structure prediction”* + +[[BioRXiv](https://www.biorxiv.org/content/10.1101/2024.01.24.577093v1.full)] [[Data](https://huggingface.co/rouskinlab)] + +## Install + +```bash +pip install efold +``` + + +## Inference mode + +### Using the command line + +From a sequence: + +```bash +efold AAACAUGAGGAUUACCCAUGU -o seq.txt +cat seq.txt + +AAACAUGAGGAUUACCCAUGU +..(((((.((....))))))) +``` + +or a fasta file: + +```bash +efold --fasta example.fasta +``` + +Using different formats: +```bash +efold AAACAUGAGGAUUACCCAUGU -bp # base pairs +efold AAACAUGAGGAUUACCCAUGU -db # dotbracket (default) +``` + +Output can be .json, .csv or .txt +```bash +efold AAACAUGAGGAUUACCCAUGU -o output.csv +``` + +Run help: +```bash +efold -h +``` + +### Using python + +```python +>>> from efold.api import run +>>> run.run('AAACAUGAGGAUUACCCAUGU', fmt='dotbracket') +..(((((.((....))))))) +``` + +## Inference speed +Tested on a AMD EPYC 7272 12 core processor, with 32GB RAM and a RTX3090 GPU + +![alt text](tests/speed_comparison.jpg) + +## File structure + +```bash +efold/ + api/ # for inference calls + core/ # backend + models/ # where we define eFold and other models + resources/ + efold_weights.py # our best model weights +scripts/ + efold_training.py # our training script + [...] +LICENSE +requirements.txt +pyproject.toml +``` + +## Data + +### List of the datasets we used + +A breakdown of the data we used is summarized [here](https://github.com/rouskinlab/efold_data). 
All the data is stored on the [HuggingFace](https://huggingface.co/rouskinlab). + +### Get the data + +You can download our datasets using [rouskinHF](https://github.com/rouskinlab/rouskinhf): + +```bash +pip install rouskinhf +``` + +And in your code, write: + +```python +>>> import rouskinhf +>>> data = rouskinhf.get_dataset('ribo500-blast') # look at the dataset names on huggingface +``` + + + +## Reproducing our results + +### Training + +A [training script](scripts/efold_training.py) is provided to train eFold from scratch. + +### Testing + +A [notebook](tests/test_eFold.ipynb) is provided to run eFold inference on the four test sets, compute the F1 score and check the validity of the structures. + + +## Citation + +**Plain text:** + +Albéric A. de Lajarte, Yves J. Martin des Taillades, Colin Kalicki, Federico Fuchs Wightman, Justin Aruda, Dragui Salazar, Matthew F. Allan, Casper L’Esperance-Kerckhoff, Alex Kashi, Fabrice Jossinet, Silvi Rouskin. “Diverse Database and Machine Learning Model to narrow the generalization gap in RNA structure prediction”. bioRxiv 2024.01.24.577093; doi: https://doi.org/10.1101/2024.01.24.577093. 2024 + +**BibTex:** + +``` +@article {Lajarte_Martin_2024, + title = {Diverse Database and Machine Learning Model to narrow the generalization gap in RNA structure prediction}, + author = {Alb{\'e}ric A. de Lajarte and Yves J. Martin des Taillades and Colin Kalicki and Federico Fuchs Wightman and Justin Aruda and Dragui Salazar and Matthew F. Allan and Casper L{\textquoteright}Esperance-Kerckhoff and Alex Kashi and Fabrice Jossinet and Silvi Rouskin}, + year = {2024}, + doi = {10.1101/2024.01.24.577093}, + URL = {https://www.biorxiv.org/content/early/2024/01/25/2024.01.24.577093}, + journal = {bioRxiv} +} + +``` diff --git a/efold/api/run.py b/efold/api/run.py index 864baf1..f30da30 100644 --- a/efold/api/run.py +++ b/efold/api/run.py @@ -1,9 +1,10 @@ -import numpy as np import os -import torch -from os.path import join, dirname +from os.path import dirname, join from typing import List, Union +import numpy as np +import torch + from efold.core import batch, embeddings, postprocess from efold.models import factory from efold.util import format_conversion @@ -87,7 +88,7 @@ def run( if not os.path.exists(arg): raise ValueError("File not found") sequences = _load_sequences_from_fasta(arg) - elif type(arg) == str: + elif isinstance(arg, str): sequences = [arg] elif hasattr(arg, "__iter__") and all([isinstance(s, str) for s in arg]): sequences = arg @@ -126,7 +127,7 @@ def run( structure = _predict_structure(model, seq, device=device) if fmt == "dotbracket": db_structure = format_conversion.convert_bp_list_to_dotbracket(structure, len(seq)) - if db_structure != None: + if db_structure is not None: structure = db_structure structures.append(structure) diff --git a/efold/cli.py b/efold/cli.py index 2f4970e..2c626f0 100644 --- a/efold/cli.py +++ b/efold/cli.py @@ -1,4 +1,5 @@ import json + import click from efold.api import run diff --git a/efold/core/batch.py b/efold/core/batch.py index 2367366..bc8a73d 100644 --- a/efold/core/batch.py +++ b/efold/core/batch.py @@ -92,11 +92,11 @@ def from_dataset_items( [ embeddings.base_pairs_to_pairing_matrix( dp["structure"]["true"], - l, + len_, padding=L, pad_value=structure_padding_value, ) - for (dp, l) in zip(batch_data, length) + for (dp, len_) in zip(batch_data, length) ] ), error=None, @@ -151,9 +151,9 @@ def get(self, data_type: str, index: int = None, to_numpy: bool = False): if index is not None: out = out[index] if 
hasattr(out, "__len__"): - l = self.get("length")[index] + len_ = self.get("length")[index] if data_type == "structure": - out = out[:l, :l] + out = out[:len_, :len_] else: out = out[: self.get("length")[index]] @@ -182,7 +182,7 @@ def get_pairs(self, data_type: str, to_numpy: bool = False) -> tuple: def count(self, data_type: str) -> int: if data_type in ["reference", "sequence", "length"]: return self.batch_size - if not data_type in self.dt_count or getattr(self, data_type) is None: + if data_type not in self.dt_count or getattr(self, data_type) is None: return 0 return self.dt_count[data_type] diff --git a/efold/core/callbacks.py b/efold/core/callbacks.py index 456123e..929c12d 100644 --- a/efold/core/callbacks.py +++ b/efold/core/callbacks.py @@ -1,10 +1,9 @@ import lightning.pytorch as pl -from lightning.pytorch import LightningModule, Trainer -from lightning.pytorch.utilities import rank_zero_only import wandb +from lightning.pytorch import Trainer +from lightning.pytorch.utilities import rank_zero_only -from efold import settings -from efold.core import batch, datamodule, loader, logger, metrics, visualisation +from efold.core import loader class ModelCheckpoint(pl.Callback): diff --git a/efold/core/datamodule.py b/efold/core/datamodule.py index 8114561..8316529 100644 --- a/efold/core/datamodule.py +++ b/efold/core/datamodule.py @@ -1,8 +1,8 @@ import datetime +from typing import List, Union + import lightning.pytorch as pl -import numpy as np -from torch.utils.data import random_split, Subset -from typing import Union, List +from torch.utils.data import Subset from efold import settings from efold.core import dataloader, dataset, sampler @@ -64,7 +64,7 @@ def __init__( "predict": predict_split, } if strategy in ["ddp", "sorted"]: - assert shuffle_valid == shuffle_train == False, ( + assert shuffle_valid == shuffle_train is False, ( "You can't shuffle in ddp or sorted mode. Set shuffle_train and shuffle_valid to 0 or use strategy='random'." 
) self.shuffle = { @@ -78,7 +78,6 @@ def __init__( "use_error": use_error, "force_download": force_download, "tqdm": tqdm, - "max_len": max_len, "min_len": min_len, } self.buckets = buckets @@ -96,8 +95,8 @@ def _use_multiple_datasets(self, name): def _dataset_merge(self, datasets): merge = datasets[0] collate_fn = merge.collate_fn - for dataset in datasets[1:]: - merge = merge + dataset + for ds in datasets[1:]: + merge = merge + ds merge.collate_fn = collate_fn return merge @@ -117,7 +116,7 @@ def setup(self, stage: str = None): self.collate_fn = self.all_datasets.collate_fn if stage == "fit": - if self.splits["train"] == None or self.splits["train"] == 1.0: + if self.splits["train"] is None or self.splits["train"] == 1.0: self.train_set = self.all_datasets else: num_datapoints = ( diff --git a/efold/core/dataset.py b/efold/core/dataset.py index 5810d57..aa64535 100644 --- a/efold/core/dataset.py +++ b/efold/core/dataset.py @@ -1,13 +1,12 @@ import os -import numpy as np -import torch -from torch.utils.data import ConcatDataset, Dataset as TorchDataset, Dataset from typing import List +import numpy as np from rouskinhf import get_dataset +from torch.utils.data import Dataset as TorchDataset from efold import settings -from efold.core import batch, datatype, embeddings, path, util +from efold.core import batch, datatype, path class Dataset(TorchDataset): @@ -51,7 +50,7 @@ def _remove_sequences_out_of_length_interval(self, min_len, max_len): min_len = 0 if min_len > max_len: raise ValueError("min_len must be smaller than max_len") - idx_out = [i for i, l in enumerate(self.length) if l >= max_len or l <= min_len] + idx_out = [i for i, len_ in enumerate(self.length) if len_ >= max_len or len_ <= min_len] for idx in idx_out[::-1]: del self.refs[idx] del self.length[idx] @@ -200,7 +199,7 @@ def __getitem__(self, index) -> tuple: "length": self.length[index], } for attr in ["dms", "shape", "structure"]: - out[attr] = getattr(self, attr)[index] if getattr(self, attr) != None else None + out[attr] = getattr(self, attr)[index] if getattr(self, attr) is not None else None return out def collate_fn(self, batch_data): diff --git a/efold/core/datatype.py b/efold/core/datatype.py index 019cd0c..007e329 100644 --- a/efold/core/datatype.py +++ b/efold/core/datatype.py @@ -1,7 +1,6 @@ import torch from efold import settings -from efold.core import util class DataType: diff --git a/efold/core/embeddings.py b/efold/core/embeddings.py index e8d0d07..daa55e9 100644 --- a/efold/core/embeddings.py +++ b/efold/core/embeddings.py @@ -1,5 +1,5 @@ -from torch import nn import torch +from torch import nn from efold import settings diff --git a/efold/core/loader.py b/efold/core/loader.py index ff6d33d..403df57 100644 --- a/efold/core/loader.py +++ b/efold/core/loader.py @@ -1,7 +1,8 @@ +import os +from os import listdir, makedirs from os.path import dirname -from os import makedirs, listdir + import torch -import os class Loader: diff --git a/efold/core/logger.py b/efold/core/logger.py index a6fbeab..61ec862 100644 --- a/efold/core/logger.py +++ b/efold/core/logger.py @@ -1,7 +1,8 @@ -import wandb import os + import lightning.pytorch as pl import matplotlib.pyplot as plt +import wandb from efold.settings import * diff --git a/efold/core/metrics.py b/efold/core/metrics.py index e64f09e..93b7be5 100644 --- a/efold/core/metrics.py +++ b/efold/core/metrics.py @@ -138,7 +138,7 @@ def update(self, batch: batch.Batch): return self def compute(self) -> dict: - out = {} + out: dict = {} for dt in self.data_type: out[dt] = {} 
for metric in settings.POSSIBLE_METRICS[dt]: diff --git a/efold/core/model.py b/efold/core/model.py index f89886e..d8b55f6 100644 --- a/efold/core/model.py +++ b/efold/core/model.py @@ -1,9 +1,10 @@ from typing import Any + import lightning.pytorch as pl -from lightning.pytorch.utilities.types import STEP_OUTPUT import torch import torch.nn as nn import torch.nn.functional as F +from lightning.pytorch.utilities.types import STEP_OUTPUT from efold import settings from efold.core import batch, metrics, postprocess @@ -126,7 +127,7 @@ def training_step(self, batch: batch.Batch, batch_idx: int): predictions = self.forward(batch) batch.integrate_prediction(predictions) loss = self.loss_fn(batch)[0] - self.log(f"train/loss", loss, sync_dist=True) + self.log("train/loss", loss, sync_dist=True) return loss def on_validation_start(self): @@ -168,8 +169,8 @@ def on_validation_epoch_end(self) -> None: # aggregate the stack and log it for metrics_dl in self.metrics_stack: metrics_pack = metrics_dl.compute() - for dt, metrics in metrics_pack.items(): - for name, metric in metrics.items(): + for dt, metrics_dict in metrics_pack.items(): + for name, metric in metrics_dict.items(): # to replace with a gather_all? self.log( f"valid/{metrics_dl.name}/{dt}/{name}", diff --git a/efold/core/path.py b/efold/core/path.py index 50a70c4..4d79cdd 100644 --- a/efold/core/path.py +++ b/efold/core/path.py @@ -1,8 +1,8 @@ -from os.path import join import os -from rouskinhf.env import Env -import numpy as np import pickle +from os.path import join + +import numpy as np from rouskinhf.path import Path as RouskinPath diff --git a/efold/core/postprocess.py b/efold/core/postprocess.py index 0903207..2797a1e 100644 --- a/efold/core/postprocess.py +++ b/efold/core/postprocess.py @@ -42,7 +42,7 @@ def mask_sharpLoops(self, input_matrix, min_hairpin_length): def mask_nonCanonical(self, sequence): # Embed sequence - if type(sequence) == str: + if isinstance(sequence, str): sequence = torch.tensor([settings.seq2int[a] for a in sequence]) # make the pairing matrix @@ -61,14 +61,14 @@ def mask_nonCanonical(self, sequence): class HungarianAlgorithm: def run(self, bppm, threshold=0.5): """Runs the Hungarian algorithm on the input bppm matrix - + Args: - bppm (torch.Tensor): n x n matrix of base pair probabilities - + Example: >>> inpt = np.diag(np.ones(10))[::-1] >>> inpt += np.random.normal(0, 0.2, inpt.shape) - >>> inpt = (inpt + inpt.T)/2 + >>> inpt = (inpt + inpt.T)/2 >>> inpt = torch.tensor(inpt) >>> out = HungarianAlgorithm().run(inpt) >>> assert (out == [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],\ @@ -91,7 +91,7 @@ def run(self, bppm, threshold=0.5): assert self.is_symmetric(bppm), f"The input bppm matrix should be symmetric, {bppm}" # just work with numpy (needed for the optimization step) - if type(bppm) == torch.Tensor: + if isinstance(bppm, torch.Tensor): device = bppm.device bppm = bppm.cpu().numpy() diff --git a/efold/core/sampler.py b/efold/core/sampler.py index fbf44d2..b253d7c 100644 --- a/efold/core/sampler.py +++ b/efold/core/sampler.py @@ -1,13 +1,13 @@ -from torch.utils.data import Sampler, Subset +import math +import os +from typing import Iterator, Optional, TypeVar, Union + import numpy as np +import torch +import torch.distributed as dist # from random import shuffle -from torch.utils.data import Dataset -from typing import Union, Optional, TypeVar, Iterator -import torch.distributed as dist -import math -import torch -import os +from torch.utils.data import Dataset, Sampler, Subset T_co = 
TypeVar("T_co", covariant=True) diff --git a/efold/core/visualisation.py b/efold/core/visualisation.py index 0a6f281..3797cf8 100644 --- a/efold/core/visualisation.py +++ b/efold/core/visualisation.py @@ -1,5 +1,5 @@ -from matplotlib import pyplot as plt import numpy as np +from matplotlib import pyplot as plt from rouskinhf import int2seq from efold import settings diff --git a/efold/models/cnn.py b/efold/models/cnn.py index 2a02820..4e11e4f 100644 --- a/efold/models/cnn.py +++ b/efold/models/cnn.py @@ -1,8 +1,8 @@ import numpy as np import torch -from torch import nn, Tensor -from einops import rearrange import torch.nn.functional as F +from einops import rearrange +from torch import Tensor, nn from efold.core import batch, model diff --git a/efold/models/efold.py b/efold/models/efold.py index be9ca9e..6fda990 100644 --- a/efold/models/efold.py +++ b/efold/models/efold.py @@ -1,10 +1,12 @@ +from collections import defaultdict +from contextlib import ExitStack +from typing import List, Union + import numpy as np import torch -from torch import nn, Tensor -from contextlib import ExitStack -from einops import rearrange import torch.nn.functional as F -from collections import defaultdict +from einops import rearrange +from torch import Tensor, nn from efold.core import batch, model @@ -471,7 +473,7 @@ class Dropout(nn.Module): along a particular dimension. """ - def __init__(self, r: float, batch_dim: T.Union[int, T.List[int]]): + def __init__(self, r: float, batch_dim: Union[int, List[int]]): super(Dropout, self).__init__() self.r = r @@ -733,7 +735,7 @@ def __init__( self.dropout = nn.Dropout(dropout) self._dropout_rate = dropout - input_max = (self.num_heads * self.head_size) ** -0.5 + (self.num_heads * self.head_size) ** -0.5 ### DOUBLE CHECK THIS CODE: self.query = nn.Linear(num_heads * head_size, num_heads * head_size, bias=False) self.key = nn.Linear(num_heads * head_size, num_heads * head_size, bias=False) @@ -843,7 +845,7 @@ def __init__(self, kernel_sizes=None, strides=None, **kwargs): super(RelPositionMultiHeadAttention, self).__init__(**kwargs) num_pos_features = self.num_heads * self.head_size - input_max = (self.num_heads * self.head_size) ** -0.5 + (self.num_heads * self.head_size) ** -0.5 self.pos_kernel = nn.Parameter( torch.rand(self.num_heads, num_pos_features, self.head_size) * 2 - 1 ) @@ -921,9 +923,9 @@ def __init__( self.scale = nn.Parameter(torch.ones(input_dim)) self.bias = nn.Parameter(torch.zeros(input_dim)) - pw1_max = input_dim**-0.5 - dw_max = kernel_size**-0.5 - pw2_max = input_dim**-0.5 + input_dim**-0.5 + kernel_size**-0.5 + input_dim**-0.5 self.pw_conv_1 = nn.Conv1d( in_channels=input_dim, diff --git a/efold/models/ribonanza.py b/efold/models/ribonanza.py index bae485d..bed56fd 100644 --- a/efold/models/ribonanza.py +++ b/efold/models/ribonanza.py @@ -77,9 +77,6 @@ def forward(self, structure): return structure * weights.unsqueeze(-1).unsqueeze(-1) -from torch.nn.functional import multi_head_attention_forward - - class SelfAttention(nn.Module): def __init__( self, diff --git a/efold/models/transformer.py b/efold/models/transformer.py index 9f762b6..1fbc0e9 100644 --- a/efold/models/transformer.py +++ b/efold/models/transformer.py @@ -1,6 +1,6 @@ import numpy as np import torch -from torch import nn, Tensor +from torch import Tensor, nn from torch.nn import TransformerEncoderLayer from efold.core import batch, model @@ -100,7 +100,7 @@ def forward(self, batch: batch.Batch) -> Tensor: src = self.encoder(src) src = self.pos_encoder(src) - for i, l in 
enumerate(self.transformer_encoder): + for i, _ in enumerate(self.transformer_encoder): src = self.transformer_encoder[i](src) src = self.resnet(src.unsqueeze(dim=1)).squeeze(dim=1) @@ -156,7 +156,7 @@ def __init__(self, n_blocks, dim_in, dim_out, kernel_size, dropout=0.0): # Basic Residula block self.res_layers = [] for i in range(n_blocks): - dilation = pow(2, (i % 3)) + pow(2, (i % 3)) self.res_layers.append( ResBlock( inplanes=dim_in, diff --git a/efold/models/unet.py b/efold/models/unet.py index 50f24e3..8f274d6 100644 --- a/efold/models/unet.py +++ b/efold/models/unet.py @@ -1,8 +1,8 @@ -import torch -from torch import nn, Tensor -import torch.nn.functional as F from collections import defaultdict +import torch +from torch import Tensor, nn + from efold import settings from efold.core import batch, model diff --git a/efold/settings.py b/efold/settings.py index 5b2837d..4e5208e 100644 --- a/efold/settings.py +++ b/efold/settings.py @@ -3,7 +3,7 @@ import torch import yaml -from torch import backends, cuda, float32 +from torch import cuda, float32 _settings_path = Path(__file__).parent / "settings.yaml" diff --git a/scripts/cnn_template.py b/scripts/cnn_template.py index 01aca86..1b6a4aa 100644 --- a/scripts/cnn_template.py +++ b/scripts/cnn_template.py @@ -5,18 +5,17 @@ import os import sys + import wandb -from lightning.pytorch.strategies import DDPStrategy -from lightning.pytorch.loggers import WandbLogger -import pandas as pd -from lightning.pytorch.callbacks.early_stopping import EarlyStopping from lightning.pytorch import Trainer from lightning.pytorch.callbacks import LearningRateMonitor -from lightning.pytorch.profilers import PyTorchProfiler +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.strategies import DDPStrategy from efold import settings from efold.core import callbacks, datamodule from efold.models import factory + # import envbash # envbash.load.load_envbash('.env') @@ -28,7 +27,7 @@ n_gpu = 8 USE_WANDB = 0 STRATEGY = "random" - print("Running on device: {}".format(device)) + print("Running on device: {}".format(settings.device)) if USE_WANDB: project = "Structure-classic" wandb_logger = WandbLogger(project=project) @@ -69,7 +68,7 @@ model.load_state_dict( torch.load( "/Users/alberic/Desktop/lively-waterfall-8_epoch45.pt", - map_location=torch.device(device), + map_location=torch.device(settings.device), ) ) @@ -77,7 +76,7 @@ wandb_logger.watch(model, log="all") trainer = Trainer( - accelerator=device, + accelerator=settings.device, devices=n_gpu if STRATEGY == "ddp" else 1, strategy=DDPStrategy(find_unused_parameters=False) if STRATEGY == "ddp" else "auto", # precision="16-mixed", diff --git a/scripts/efold_training.py b/scripts/efold_training.py index 7059e5a..3943db9 100644 --- a/scripts/efold_training.py +++ b/scripts/efold_training.py @@ -3,17 +3,16 @@ sys.path.append(os.path.abspath(".")) -from lightning.pytorch.strategies import DDPStrategy import wandb -from lightning.pytorch.loggers import WandbLogger from lightning.pytorch import Trainer from lightning.pytorch.callbacks import LearningRateMonitor +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.strategies import DDPStrategy from efold import settings from efold.core import callbacks, datamodule from efold.models import factory - # Train loop if __name__ == "__main__": USE_WANDB = False diff --git a/scripts/mlp-template.py b/scripts/mlp-template.py index 732f5c6..8544c6f 100644 --- a/scripts/mlp-template.py +++ b/scripts/mlp-template.py @@ -1,14 +1,12 @@ 
-import numpy as np import os -import pandas as pd import sys + import wandb -from lightning.pytorch.callbacks.early_stopping import EarlyStopping -from lightning.pytorch.loggers import WandbLogger from lightning.pytorch import Trainer +from lightning.pytorch.loggers import WandbLogger from efold import settings -from efold.core import callbacks, datamodule, metrics +from efold.core import callbacks, datamodule from efold.models import factory sys.path.append(os.path.abspath(".")) @@ -20,7 +18,7 @@ # why do you need this? if __name__ == "__main__": - print("Running on device: {}".format(device)) + print("Running on device: {}".format(settings.device)) BATCH_SIZE = 4 LR = 5e-5 @@ -73,7 +71,7 @@ max_epochs=1, log_every_n_steps=1, logger=wandb_logger, - accelerator=device, + accelerator=settings.device, callbacks=[ callbacks.ModelCheckpoint(every_n_epoch=1), ], diff --git a/scripts/ribonanza-template.py b/scripts/ribonanza-template.py index 2489f71..e08caa8 100644 --- a/scripts/ribonanza-template.py +++ b/scripts/ribonanza-template.py @@ -1,10 +1,11 @@ import os import sys + import torch import wandb +from lightning.pytorch import Trainer from lightning.pytorch.callbacks import LearningRateMonitor from lightning.pytorch.loggers import WandbLogger -from lightning.pytorch import Trainer from lightning.pytorch.strategies import DDPStrategy sys.path.append(os.path.dirname(os.path.dirname(__file__))) @@ -15,7 +16,7 @@ if __name__ == "__main__": USE_WANDB = True - print("Running on device: {}".format(device)) + print("Running on device: {}".format(settings.device)) if USE_WANDB: wandb_logger = WandbLogger(project="ribonanza-solution", name="first-run") @@ -61,7 +62,7 @@ devices=8, strategy=DDPStrategy(find_unused_parameters=True), max_epochs=1000, - accelerator=device, + accelerator=settings.device, logger=wandb_logger if USE_WANDB else None, callbacks=[ LearningRateMonitor(logging_interval="epoch"), diff --git a/scripts/transformer-template.py b/scripts/transformer-template.py index e219619..f4dca33 100644 --- a/scripts/transformer-template.py +++ b/scripts/transformer-template.py @@ -1,11 +1,13 @@ -import envbash import os import sys + +import envbash import wandb +from lightning.pytorch import Trainer from lightning.pytorch.callbacks import LearningRateMonitor from lightning.pytorch.loggers import WandbLogger -from lightning.pytorch import Trainer +from efold import settings from efold.core import callbacks, datamodule from efold.models import factory @@ -16,7 +18,7 @@ if __name__ == "__main__": USE_WANDB = True - print("Running on device: {}".format(device)) + print("Running on device: {}".format(settings.device)) if USE_WANDB: wandb_logger = WandbLogger(project="CHANGE_ME", name="debug") @@ -63,7 +65,7 @@ # train with both splits trainer = Trainer( - accelerator=device, + accelerator=settings.device, # devices=4, # strategy="ddp", # precision="16-mixed", @@ -95,7 +97,7 @@ ) trainer = Trainer( - accelerator=device, + accelerator=settings.device, devices=1, callbacks=[ # don't change this diff --git a/scripts/unet_training.py b/scripts/unet_training.py index 87d6fac..4365743 100644 --- a/scripts/unet_training.py +++ b/scripts/unet_training.py @@ -4,21 +4,20 @@ sys.path.append(os.path.abspath(".")) import wandb -from lightning.pytorch.strategies import DDPStrategy -from lightning.pytorch.loggers import WandbLogger from lightning.pytorch import Trainer from lightning.pytorch.callbacks import LearningRateMonitor +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.strategies 
import DDPStrategy from efold import settings from efold.core import callbacks, datamodule from efold.models import factory - # Train loop if __name__ == "__main__": USE_WANDB = False STRATEGY = "random" - print("Running on device: {}".format(device)) + print("Running on device: {}".format(settings.device)) if USE_WANDB: wandb_logger = WandbLogger(project="test") @@ -50,7 +49,7 @@ wandb_logger.watch(model, log="all") trainer = Trainer( - accelerator=device, + accelerator=settings.device, devices=8 if STRATEGY == "ddp" else 1, strategy=DDPStrategy(find_unused_parameters=False) if STRATEGY == "ddp" else "auto", max_epochs=15, diff --git a/setup.py b/setup.py index 8e26cc9..7c9b84b 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( name="efold", diff --git a/tests/test_eFold.ipynb b/tests/test_eFold.ipynb index 67a850f..180d58d 100644 --- a/tests/test_eFold.ipynb +++ b/tests/test_eFold.ipynb @@ -6,14 +6,12 @@ "metadata": {}, "outputs": [], "source": [ - "import efold.core as core\n", + "import os\n", "\n", - "import pandas as pd\n", "import numpy as np\n", - "from rouskinhf import get_dataset\n", + "import pandas as pd\n", "import torch\n", - "\n", - "import os\n", + "from rouskinhf import get_dataset\n", "\n", "os.environ[\"HUGGINGFACE_TOKEN\"] = \"your key here\"" ] @@ -142,7 +140,9 @@ "outputs": [], "source": [ "import time\n", + "\n", "from tqdm import tqdm\n", + "\n", "from efold.api.run import run\n", "\n", "eFold_processed = pd.DataFrame()\n", diff --git a/tests/test_speed_eFold.py b/tests/test_speed_eFold.py index 4521c6e..91b11dc 100644 --- a/tests/test_speed_eFold.py +++ b/tests/test_speed_eFold.py @@ -1,18 +1,12 @@ -import sys, os - -file_dir = os.path.dirname(os.path.realpath(__file__)) -sys.path.append(os.path.join(file_dir, "../../eFold")) +import os +import time -from efold.api import run as run_module -import pandas as pd import numpy as np -from rouskinhf import get_dataset -import torch - -import time +from rnastructure_wrapper import RNAstructure from tqdm import tqdm +import plotly.graph_objects as go -from rnastructure_wrapper import RNAstructure +from efold.api import run as run_module Fold = RNAstructure(path="/root/RNAstructure/exe/") @@ -42,8 +36,6 @@ dT = time.time() - t0 rnaStructure_dTs.append(dT) -import plotly.graph_objects as go - fig = go.Figure() fig.add_trace(go.Scatter(x=lengths, y=rnaStructure_dTs, mode="lines", name="RNAstructure")) fig.add_trace(go.Scatter(x=lengths, y=efold_CPU_dTs, mode="lines", name="eFold (CPU)")) From eccf30178b0b89bbffc9882f0a0fcbe609481437 Mon Sep 17 00:00:00 2001 From: taillades Date: Sun, 19 Oct 2025 14:08:40 -0700 Subject: [PATCH 3/4] ruff passes --- efold/core/logger.py | 2 -- efold/core/postprocess.py | 3 ++- efold/core/sampler.py | 3 ++- efold/models/efold.py | 3 ++- efold/models/unet.py | 3 ++- efold/util/format_conversion.py | 6 ++++-- pyproject.toml | 3 ++- tests/test_speed_eFold.py | 2 +- 8 files changed, 15 insertions(+), 10 deletions(-) diff --git a/efold/core/logger.py b/efold/core/logger.py index 61ec862..8ae4991 100644 --- a/efold/core/logger.py +++ b/efold/core/logger.py @@ -4,8 +4,6 @@ import matplotlib.pyplot as plt import wandb -from efold.settings import * - class LocalLogger: def __init__(self, path: str = "local_testing_output", overwrite: bool = False): diff --git a/efold/core/postprocess.py b/efold/core/postprocess.py index 2797a1e..30bb39c 100644 --- a/efold/core/postprocess.py +++ b/efold/core/postprocess.py 
@@ -186,7 +186,8 @@ def contact_a(a_hat, m):
     # grad_a = (lmbd * soft_sign(torch.sum(contact_a(a_hat, m), dim=-1) - 1)).unsqueeze_(-1).expand(u.shape) - u / 2
     # grad = a_hat * m * (grad_a + torch.transpose(grad_a, -1, -2))
     # n2 = torch.norm(grad)
-    # print([t, 'norms', n1, n2, aug_lagrangian(u, m, a_hat, lmbd), torch.sum(contact_a(a_hat, u))])
+    # print([t, 'norms', n1, n2, aug_lagrangian(u, m, a_hat, lmbd),
+    #        torch.sum(contact_a(a_hat, u))])
 
     a = a_hat * a_hat
     a = (a + torch.transpose(a, -1, -2)) / 2
diff --git a/efold/core/sampler.py b/efold/core/sampler.py
index b253d7c..9d5a684 100644
--- a/efold/core/sampler.py
+++ b/efold/core/sampler.py
@@ -80,7 +80,8 @@ def __init__(
             rank = dist.get_rank()
         if rank >= num_replicas or rank < 0:
             raise ValueError(
-                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}] because num_replicas={num_replicas}"
+                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}] "
+                f"because num_replicas={num_replicas}"
             )
         self.dataset = dataset
         self.num_replicas = num_replicas
diff --git a/efold/models/efold.py b/efold/models/efold.py
index 6fda990..bd7338d 100644
--- a/efold/models/efold.py
+++ b/efold/models/efold.py
@@ -773,7 +773,8 @@ def call_qkv(self, query, key, value, training=False):
         return query, key, value
 
     def call_attention(self, query, key, value, logits, bias=None, training=False, mask=None):
-        # Mask = attention mask with shape [B, Tquery, Tkey] with 1 for positions we want to attend, 0 for masked
+        # Mask = attention mask with shape [B, Tquery, Tkey] with 1 for positions we want to attend,
+        # 0 for masked
         if mask is not None:
             if len(mask.size()) < 2:
                 raise ValueError("'mask' must have at least 2 dimensions")
diff --git a/efold/models/unet.py b/efold/models/unet.py
index 8f274d6..b398bd6 100644
--- a/efold/models/unet.py
+++ b/efold/models/unet.py
@@ -60,7 +60,8 @@ def forward(self, batch: batch.Batch) -> Tensor:
         #     return l
 
         # pad_len = get_cut_len(src.shape[1], 80)-src.shape[1]
-        # src = torch.cat( (src, torch.zeros((src.shape[0], pad_len), device=self.device, dtype=torch.long) ), dim=-1)
+        # src = torch.cat( (src, torch.zeros((src.shape[0], pad_len), device=self.device,
+        #                   dtype=torch.long) ), dim=-1)
 
         x = self.seq2map(src)
 
diff --git a/efold/util/format_conversion.py b/efold/util/format_conversion.py
index 27d2c83..7882433 100644
--- a/efold/util/format_conversion.py
+++ b/efold/util/format_conversion.py
@@ -86,9 +86,11 @@ def _get_non_redudant_bp_list(conflict_list: list) -> list:
 def _get_list_bp_conflicts(bp_list: list[tuple[int, int]]) -> list:
     """given a bp_list gives the list of conflicts bp-s which indicate PK structure
     Args:
-        bp_list: of list of base pairs where the base pairs are list of indeces of the bp in increasing order (bp[0]<bp[1])

From: taillades
Date: Sun, 19 Oct 2025 14:56:41 -0700
Subject: [PATCH 4/4] remove config.py, use constants config instead

---
 MANIFEST.in                     |   3 +-
 efold/api/run.py                |   4 +-
 efold/constants.py              | 199 ++++++++++++++++++++++++++++++++
 efold/constants.yaml            |  62 ++++++++++
 efold/core/batch.py             |  16 +--
 efold/core/datamodule.py        |  10 +-
 efold/core/dataset.py           |  12 +-
 efold/core/datatype.py          |  10 +-
 efold/core/embeddings.py        |  12 +-
 efold/core/loader.py            |   3 +-
 efold/core/metrics.py           |   8 +-
 efold/core/model.py             |  14 +--
 efold/core/postprocess.py       |   8 +-
 efold/core/sampler.py           |   4 +-
 efold/core/util.py              |   6 +-
 efold/core/visualisation.py     |   4 +-
 efold/models/cnn.py             |   2 +-
 efold/models/efold.py           |  11 +-
 efold/models/ribonanza.py       |  12 +-
 efold/models/transformer.py     |   2 +-
 efold/models/unet.py            |   6 +-
efold/settings.py | 56 --------- efold/settings.yaml | 66 ----------- efold/util/format_conversion.py | 3 +- scripts/cnn_template.py | 8 +- scripts/efold_training.py | 6 +- scripts/mlp-template.py | 6 +- scripts/ribonanza-template.py | 6 +- scripts/transformer-template.py | 8 +- scripts/unet_training.py | 6 +- setup.py | 2 +- 31 files changed, 356 insertions(+), 219 deletions(-) create mode 100644 efold/constants.py create mode 100644 efold/constants.yaml delete mode 100644 efold/settings.py delete mode 100644 efold/settings.yaml diff --git a/MANIFEST.in b/MANIFEST.in index 0674363..d54bfb5 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1 @@ -requirements.txt -include efold/settings.yaml \ No newline at end of file +requirements.txt \ No newline at end of file diff --git a/efold/api/run.py b/efold/api/run.py index f30da30..69ad475 100644 --- a/efold/api/run.py +++ b/efold/api/run.py @@ -1,6 +1,6 @@ import os from os.path import dirname, join -from typing import List, Union +from typing import List, Optional, Union import numpy as np import torch @@ -53,7 +53,7 @@ def _predict_structure(model, sequence: str, device: str = "cpu") -> list[tuple[ def run( - arg: Union[str, List[str]] = None, fmt: str = "dotbracket", device: str = None + arg: Optional[Union[str, List[str]]] = None, fmt: str = "dotbracket", device: Optional[str] = None ) -> dict[str, Union[str, list[tuple[int, int]]]]: """Runs the Efold API on the provided sequence or fasta file. diff --git a/efold/constants.py b/efold/constants.py new file mode 100644 index 0000000..e0a6d5c --- /dev/null +++ b/efold/constants.py @@ -0,0 +1,199 @@ +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +import torch +import yaml +from torch import cuda + + +@dataclass +class TokenMapping: + """Token mappings for nucleotides. + + :param seq2int: Mapping from sequence characters to integers + :param start_token: Optional start token + :param end_token: Optional end token + :param padding_token_key: Key used for padding token + """ + seq2int: dict[str, int] + start_token: Optional[str] + end_token: Optional[str] + padding_token_key: str + + @property + def int2seq(self) -> dict[int, str]: + """Reverse mapping from integers to sequences. + + :return: Dictionary mapping integers to sequence characters + """ + return {v: k for k, v in self.seq2int.items()} + + @property + def padding_token(self) -> int: + """Get the padding token integer value. + + :return: Integer value of padding token + """ + return self.seq2int[self.padding_token_key] + + +@dataclass +class PyTorchConfig: + """PyTorch configuration settings. + + :param unknown_value: Value used for unknown/missing data + :param val_gu: Value for GU base pairing + """ + unknown_value: float + val_gu: float + + +@dataclass +class TestSets: + """Test set definitions. + + :param structure: List of structure test sets + :param sequence: List of sequence test sets + :param dms: List of DMS test sets + :param shape: List of SHAPE test sets + """ + structure: list[str] + sequence: list[str] + dms: list[str] + shape: list[str] + + @property + def as_dict(self) -> dict[str, list[str]]: + """Return test sets as dictionary. + + :return: Dictionary of test set types to names + """ + return { + "structure": self.structure, + "sequence": self.sequence, + "dms": self.dms, + "shape": self.shape + } + + @property + def all_names(self) -> list[str]: + """Get all test set names flattened. 
+ + :return: List of all test set names + """ + return [name for sets in self.as_dict.values() for name in sets] + + @property + def data_types_per_test_set(self) -> list[str]: + """Get data type for each test set. + + :return: List of data types matching test sets + """ + return [dtype for dtype, names in self.as_dict.items() for _ in names] + + +@dataclass +class DataConfig: + """Data types and format configuration. + + :param types: List of available data types + :param types_format: Mapping of data types to their torch dtype + """ + types: list[str] + types_format: dict[str, str] + + @property + def types_format_torch(self) -> dict[str, torch.dtype]: + """Get data types as torch dtypes. + + :return: Dictionary mapping data types to torch.dtype objects + """ + dtype_map = { + "int32": torch.int32, + "float32": torch.float32, + "float64": torch.float64, + } + return {k: dtype_map[v] for k, v in self.types_format.items()} + + +@dataclass +class MetricsConfig: + """Metrics configuration. + + :param reference_metric: Primary metric for each data type + :param ref_metric_sign: Sign convention for metrics (1=higher is better, -1=lower is better) + :param possible_metrics: List of available metrics per data type + """ + reference_metric: dict[str, str] + ref_metric_sign: dict[str, int] + possible_metrics: dict[str, list[str]] + + +@dataclass +class Config: + """Main configuration class combining all config sections. + + :param tokens: Token mapping configuration + :param pytorch: PyTorch-specific configuration + :param test_sets: Test set definitions + :param data: Data type configuration + :param metrics: Metrics configuration + """ + tokens: TokenMapping + pytorch: PyTorchConfig + test_sets: TestSets + data: DataConfig + metrics: MetricsConfig + device: str = field(init=False) + + def __post_init__(self) -> None: + """Initialize device after other fields are set. + + :return: None + """ + object.__setattr__(self, 'device', "cuda" if cuda.is_available() else "cpu") + + +def _load_config() -> Config: + """Load configuration from YAML file. 
+ + :return: Instantiated Config object + """ + config_path = Path(__file__).parent / "constants.yaml" + + with open(config_path, 'r') as f: + data = yaml.safe_load(f) + + return Config( + tokens=TokenMapping( + seq2int=data['tokens']['seq2int'], + start_token=data['tokens']['start_token'], + end_token=data['tokens']['end_token'], + padding_token_key=data['tokens']['padding_token_key'] + ), + pytorch=PyTorchConfig( + unknown_value=data['pytorch']['unknown_value'], + val_gu=data['pytorch']['val_gu'] + ), + test_sets=TestSets( + structure=data['test_sets']['structure'], + sequence=data['test_sets']['sequence'], + dms=data['test_sets']['dms'], + shape=data['test_sets']['shape'] + ), + data=DataConfig( + types=data['data']['types'], + types_format=data['data']['types_format'] + ), + metrics=MetricsConfig( + reference_metric=data['metrics']['reference_metric'], + ref_metric_sign=data['metrics']['ref_metric_sign'], + possible_metrics=data['metrics']['possible_metrics'] + ) + ) + + +config = _load_config() + +torch.set_default_dtype(torch.float32) diff --git a/efold/constants.yaml b/efold/constants.yaml new file mode 100644 index 0000000..f098739 --- /dev/null +++ b/efold/constants.yaml @@ -0,0 +1,62 @@ +# RNA Secondary Structure Prediction Configuration + +tokens: + seq2int: + X: 0 + A: 1 + C: 2 + G: 3 + U: 4 + + start_token: null + end_token: null + padding_token_key: X + +pytorch: + unknown_value: -1000.0 + val_gu: 0.095 + +test_sets: + structure: + - PDB + - archiveII + - lncRNA_nonFiltered + - viral_fragments + sequence: [] + dms: [] + shape: [] + +data: + types: + - structure + - dms + - shape + + types_format: + structure: int32 + dms: float32 + shape: float32 + +metrics: + reference_metric: + structure: f1 + dms: mae + shape: mae + + ref_metric_sign: + structure: 1 + dms: -1 + shape: -1 + + possible_metrics: + structure: + - f1 + dms: + - mae + - r2 + - pearson + shape: + - mae + - r2 + - pearson + diff --git a/efold/core/batch.py b/efold/core/batch.py index bc8a73d..f84cf84 100644 --- a/efold/core/batch.py +++ b/efold/core/batch.py @@ -1,15 +1,17 @@ +from typing import Optional + import torch import torch.nn.functional as F -from efold import settings +from efold.constants import config from efold.core import datatype, embeddings, util def _pad(arr: torch.Tensor, L: int, data_type: str, accept_none: bool = False) -> torch.Tensor: padding_values = { "sequence": 0, - "dms": settings.UKN, - "shape": settings.UKN, + "dms": config.pytorch.unknown_value, + "shape": config.pytorch.unknown_value, } assert data_type in padding_values.keys(), ( f"Unknown data type {data_type}. If you want to pad a structure, use base_pairs_to_pairing_matrix." 
@@ -21,9 +23,9 @@ def _pad(arr: torch.Tensor, L: int, data_type: str, accept_none: bool = False) - def get_padded_vector(dp: dict, data_type: str, data_part: str, L: int) -> torch.Tensor: if getattr(dp, data_type) is None: - return torch.tensor([settings.UKN] * L) + return torch.tensor([config.pytorch.unknown_value] * L) if getattr(getattr(dp, data_type), data_part) is None: - return torch.tensor([settings.UKN] * L) + return torch.tensor([config.pytorch.unknown_value] * L) return _pad(getattr(getattr(dp, data_type), data_part), L, data_type) @@ -62,7 +64,7 @@ def from_dataset_items( batch_data: list, data_type: str, use_error: bool, - structure_padding_value: float = settings.UKN, + structure_padding_value: float = config.pytorch.unknown_value, ): reference = [dp["reference"] for dp in batch_data] length = [dp["length"] for dp in batch_data] @@ -130,7 +132,7 @@ def from_dataset_items( **data, ) - def get(self, data_type: str, index: int = None, to_numpy: bool = False): + def get(self, data_type: str, index: Optional[int] = None, to_numpy: bool = False): if data_type in ["reference", "sequence", "length"]: out = getattr(self, data_type) data_part = None diff --git a/efold/core/datamodule.py b/efold/core/datamodule.py index 8316529..7eac910 100644 --- a/efold/core/datamodule.py +++ b/efold/core/datamodule.py @@ -1,10 +1,10 @@ import datetime -from typing import List, Union +from typing import List, Optional, Union import lightning.pytorch as pl from torch.utils.data import Subset -from efold import settings +from efold.constants import config from efold.core import dataloader, dataset, sampler @@ -25,7 +25,7 @@ def __init__( use_error=False, max_len=None, min_len=None, - structure_padding_value=settings.UKN, + structure_padding_value=config.pytorch.unknown_value, tqdm=True, buckets=None, **kwargs, @@ -100,7 +100,7 @@ def _dataset_merge(self, datasets): merge.collate_fn = collate_fn return merge - def setup(self, stage: str = None): + def setup(self, stage: Optional[str] = None): if stage is None or (stage in ["fit", "predict"] and not hasattr(self, "all_datasets")): self.all_datasets = self._dataset_merge( [ @@ -165,7 +165,7 @@ def _select_test_dataset(self): data_type=[data_type], **self.dataset_args, ) - for data_type, datasets in settings.TEST_SETS.items() + for data_type, datasets in config.test_sets.as_dict.items() if data_type in self.data_type for name in datasets ] diff --git a/efold/core/dataset.py b/efold/core/dataset.py index aa64535..d910588 100644 --- a/efold/core/dataset.py +++ b/efold/core/dataset.py @@ -1,11 +1,11 @@ import os -from typing import List +from typing import List, Optional import numpy as np from rouskinhf import get_dataset from torch.utils.data import Dataset as TorchDataset -from efold import settings +from efold.constants import config from efold.core import batch, datatype, path @@ -21,9 +21,9 @@ def __init__( min_len: int, structure_padding_value: float, use_error: bool, - dms: datatype.DMSDataset = None, - shape: datatype.SHAPEDataset = None, - structure: datatype.StructureDataset = None, + dms: Optional[datatype.DMSDataset] = None, + shape: Optional[datatype.SHAPEDataset] = None, + structure: Optional[datatype.StructureDataset] = None, sort_by_length: bool = False, ) -> None: super().__init__() @@ -95,7 +95,7 @@ def from_local_or_download( use_error: bool = False, max_len=None, min_len=None, - structure_padding_value: float = settings.UKN, + structure_padding_value: float = config.pytorch.unknown_value, sort_by_length: bool = False, tqdm=True, ): diff 
--git a/efold/core/datatype.py b/efold/core/datatype.py index 007e329..2213f94 100644 --- a/efold/core/datatype.py +++ b/efold/core/datatype.py @@ -1,12 +1,14 @@ +from typing import Optional + import torch -from efold import settings +from efold.constants import config class DataType: attributes = ["true", "pred", "error"] - def __init__(self, true: list, error: list = None, pred: list = None): + def __init__(self, true: list, error: Optional[list] = None, pred: Optional[list] = None): self.true = true self.error = error self.pred = pred @@ -82,14 +84,14 @@ def from_data_json(cls, data_json: dict, L: int, refs: list): values = data_json[ref] if data_type in values: true.append( - torch.tensor(values[data_type], dtype=settings.DTYPE_PER_DATA_TYPE[data_type]) + torch.tensor(values[data_type], dtype=config.data.types_format_torch[data_type]) ) if data_type != "structure": if "error_{}".format(data_type) in values: error.append( torch.tensor( values["error_{}".format(data_type)], - dtype=settings.DTYPE_PER_DATA_TYPE[data_type], + dtype=config.data.types_format_torch[data_type], ) ) else: diff --git a/efold/core/embeddings.py b/efold/core/embeddings.py index daa55e9..dcff337 100644 --- a/efold/core/embeddings.py +++ b/efold/core/embeddings.py @@ -1,28 +1,28 @@ import torch from torch import nn -from efold import settings +from efold.constants import config -NUM_BASES = len(set(settings.seq2int.values())) +NUM_BASES = len(set(config.tokens.seq2int.values())) def sequence_to_int(sequence: str) -> torch.Tensor: - return torch.tensor([settings.seq2int[s] for s in sequence], dtype=torch.int64) + return torch.tensor([config.tokens.seq2int[s] for s in sequence], dtype=torch.int64) def int_to_sequence(sequence: torch.Tensor) -> str: - return "".join([settings.int2seq[i.item()] for i in sequence]) + return "".join([config.tokens.int2seq[i.item()] for i in sequence]) def sequence_to_one_hot(sequence_batch: torch.Tensor) -> torch.Tensor: - return nn.functional.one_hot(sequence_batch, NUM_BASES).type(settings.DEFAULT_FORMAT) + return nn.functional.one_hot(sequence_batch, NUM_BASES).type(torch.float32) def base_pairs_to_pairing_matrix( base_pairs: torch.Tensor, sequence_length: int, padding: int, - pad_value: float = settings.UKN, + pad_value: float = config.pytorch.unknown_value, ) -> torch.Tensor: pairing_matrix = torch.ones((padding, padding)) * pad_value if base_pairs is None: diff --git a/efold/core/loader.py b/efold/core/loader.py index 403df57..6e3548e 100644 --- a/efold/core/loader.py +++ b/efold/core/loader.py @@ -1,6 +1,7 @@ import os from os import listdir, makedirs from os.path import dirname +from typing import Optional import torch @@ -14,7 +15,7 @@ def __init__( makedirs(dirname(self.get_path()), exist_ok=True) @classmethod - def find_best_model(cls, prefix: str) -> "Loader": + def find_best_model(cls, prefix: str) -> Optional["Loader"]: models = [model for model in listdir("models") if model.startswith(prefix)] if len(models) == 0: return None diff --git a/efold/core/metrics.py b/efold/core/metrics.py index 93b7be5..1caa73f 100644 --- a/efold/core/metrics.py +++ b/efold/core/metrics.py @@ -1,7 +1,7 @@ import numpy as np import torch -from efold import settings +from efold.constants import config from efold.core import batch @@ -10,7 +10,7 @@ def mask_and_flatten(func): def wrapped(pred, true): if pred is None or true is None: return np.nan - mask = true != settings.UKN + mask = true != config.pytorch.unknown_value if torch.sum(mask) == 0: return np.nan pred = pred[mask] @@ -133,7 +133,7 @@ 
def __init__(self, name, data_type=["dms", "shape", "structure"]): def update(self, batch: batch.Batch): for dt in self.data_type: pred, true = batch.get_pairs(dt) - for metric in settings.POSSIBLE_METRICS[dt]: + for metric in config.metrics.possible_metrics[dt]: self._add_metric(dt, metric, metric_factory[metric](pred, true)) return self @@ -141,7 +141,7 @@ def compute(self) -> dict: out: dict = {} for dt in self.data_type: out[dt] = {} - for metric in settings.POSSIBLE_METRICS[dt]: + for metric in config.metrics.possible_metrics[dt]: out[dt][metric] = self._get_nanmean(dt, metric) return out diff --git a/efold/core/model.py b/efold/core/model.py index d8b55f6..2193e48 100644 --- a/efold/core/model.py +++ b/efold/core/model.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from lightning.pytorch.utilities.types import STEP_OUTPUT -from efold import settings +from efold.constants import config from efold.core import batch, metrics, postprocess METRIC_ARGS = dict(dist_sync_on_step=True) @@ -42,13 +42,13 @@ def __init__(self, lr: float, optimizer_fn, weight_data: bool = False, **kwargs) self.weight_data = weight_data self.save_hyperparameters(ignore=["loss_fn"]) - self.lossBCE = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([300])).to(settings.device) + self.lossBCE = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([300])).to(config.device) # Metrics self.metrics_stack = None self.tic = None - self.test_results = {"reference": [], "sequence": [], "structure": []} + self.test_results: dict[str, list[Any]] = {"reference": [], "sequence": [], "structure": []} self.postprocesser = postprocess.Postprocess() @@ -81,7 +81,7 @@ def _loss_signal(self, batch: batch.Batch, data_type: str): ## vv MSE loss vv ## mask = torch.zeros_like(true) - mask[true != settings.UKN] = 1 + mask[true != config.pytorch.unknown_value] = 1 loss = F.mse_loss(pred * mask, true * mask) non_zeros = (mask == 1).sum() / mask.numel() @@ -187,11 +187,11 @@ def test_step(self, batch: batch.Batch, batch_idx: int, dataloader_idx=0): predictions["structure"], batch.get("sequence") ) - from efold import settings + from efold.constants import config self.test_results["reference"] += batch.get("reference") self.test_results["sequence"] += [ - "".join([settings.int2seq[base] for base in seq]) + "".join([config.tokens.int2seq[base] for base in seq]) for seq in batch.get("sequence").detach().tolist() ] self.test_results["structure"] += predictions["structure"].tolist() @@ -204,7 +204,7 @@ def on_test_batch_end( ) -> None: # push the metric directly metric_pack = metrics.MetricsStack( - name=settings.TEST_SETS_NAMES[dataloader_idx], + name=config.test_sets.all_names[dataloader_idx], data_type=self.data_type_output, ) for dt, metric_dict in metric_pack.update(batch).compute().items(): diff --git a/efold/core/postprocess.py b/efold/core/postprocess.py index 30bb39c..0dd6c91 100644 --- a/efold/core/postprocess.py +++ b/efold/core/postprocess.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from scipy.optimize import linear_sum_assignment -from efold import settings +from efold.constants import config class Constraints: @@ -17,7 +17,7 @@ class Constraints: Example: >>> inpt = torch.tensor([[0.1, 0.6, 0.8],[0.6, 0.1, 0.9],[0.8, 0.9, 0.1]]) - >>> sequence = torch.tensor([settings.seq2int[a] for a in "GCU"]) + >>> sequence = torch.tensor([config.tokens.seq2int[a] for a in "GCU"]) >>> out = Constraints().apply_constraints(inpt, sequence=sequence, min_hairpin_length=0, canonical_only=True) >>> assert (out == torch.tensor([[0.0, 0.6, 0.8],[0.6, 0.0, 
0.0],[0.8, 0.0, 0.0]])).all(), "The output is not as expected: {}".format(out) @@ -43,7 +43,7 @@ def mask_sharpLoops(self, input_matrix, min_hairpin_length): def mask_nonCanonical(self, sequence): # Embed sequence if isinstance(sequence, str): - sequence = torch.tensor([settings.seq2int[a] for a in sequence]) + sequence = torch.tensor([config.tokens.seq2int[a] for a in sequence]) # make the pairing matrix sequence = sequence.reshape(-1, 1) @@ -52,7 +52,7 @@ def mask_nonCanonical(self, sequence): # find the allowable pairs allowable_pair = set() for pair in ["GU", "GC", "AU"]: - allowable_pair.add(settings.seq2int[pair[0]] + settings.seq2int[pair[1]]) + allowable_pair.add(config.tokens.seq2int[pair[0]] + config.tokens.seq2int[pair[1]]) allowable_pair = torch.tensor(list(allowable_pair), device=pair_of_bases.device) return torch.isin(pair_of_bases, allowable_pair).int() diff --git a/efold/core/sampler.py b/efold/core/sampler.py index 9d5a684..49cd5f8 100644 --- a/efold/core/sampler.py +++ b/efold/core/sampler.py @@ -67,7 +67,7 @@ def __init__( num_replicas: Optional[int] = None, rank: Optional[int] = None, shuffle: bool = True, - seed: int = os.environ.get("PL_GLOBAL_SEED", 0), + seed: int = int(os.environ.get("PL_GLOBAL_SEED", 0)), drop_last: bool = False, ) -> None: if num_replicas is None: @@ -166,7 +166,7 @@ def set_epoch(self, epoch: int) -> None: def sampler_factory( dataset: Union[Dataset, Subset], strategy: str, - seed: int = os.environ.get("PL_GLOBAL_SEED", 0), + seed: int = int(os.environ.get("PL_GLOBAL_SEED", 0)), num_replicas: Optional[int] = None, rank: Optional[int] = None, ): diff --git a/efold/core/util.py b/efold/core/util.py index 4072c9b..6da8a48 100644 --- a/efold/core/util.py +++ b/efold/core/util.py @@ -1,15 +1,15 @@ import torch import torch.nn.functional as F -from efold import settings +from efold.constants import config from efold.core import embeddings def _pad(arr: torch.Tensor, L: int, data_type: str) -> torch.Tensor: padding_values = { "sequence": 0, - "dms": settings.UKN, - "shape": settings.UKN, + "dms": config.pytorch.unknown_value, + "shape": config.pytorch.unknown_value, } if data_type == "structure": return embeddings.base_pairs_to_pairing_matrix(arr, L) diff --git a/efold/core/visualisation.py b/efold/core/visualisation.py index 3797cf8..e257c16 100644 --- a/efold/core/visualisation.py +++ b/efold/core/visualisation.py @@ -2,7 +2,7 @@ from matplotlib import pyplot as plt from rouskinhf import int2seq -from efold import settings +from efold.constants import config from efold.core import metrics matplotlib_colors = [ @@ -50,7 +50,7 @@ def known_bases_to_list(x, mask): return x[mask].cpu().numpy() pred, true, sequence = chop_array(pred), chop_array(true), chop_array(sequence) - mask = true != settings.UKN + mask = true != config.pytorch.unknown_value true, pred, sequence = ( known_bases_to_list(true, mask), known_bases_to_list(pred, mask), diff --git a/efold/models/cnn.py b/efold/models/cnn.py index 4e11e4f..e53a20c 100644 --- a/efold/models/cnn.py +++ b/efold/models/cnn.py @@ -83,7 +83,7 @@ def __init__( # nn.Linear(d_model, 1), # ) - def forward(self, batch: batch.Batch) -> Tensor: + def forward(self, batch: batch.Batch) -> dict[str, Tensor]: """ Args: src: Tensor, shape [seq_len, batch_size] diff --git a/efold/models/efold.py b/efold/models/efold.py index bd7338d..d11400b 100644 --- a/efold/models/efold.py +++ b/efold/models/efold.py @@ -34,7 +34,7 @@ def __init__( self.data_type_output = ["structure"] self.lr = lr self.gamma = gamma - self.train_losses = 
[] + self.train_losses: list[float] = [] self.loss = nn.MSELoss() # Encoder layers @@ -80,7 +80,7 @@ def __init__( ResLayer(dim_in=d_cnn // 2, dim_out=1, n_blocks=4, kernel_size=3, dropout=dropout), ) - def forward(self, batch: batch.Batch) -> Tensor: + def forward(self, batch: batch.Batch) -> dict[str, Tensor]: # Encoding of RNA sequence src = batch.get("sequence") @@ -617,7 +617,6 @@ def __init__(self, n_blocks, dim_in, dim_out, kernel_size, dropout=0.0): # Basic Residula block self.res_layers = [] for i in range(n_blocks): - # dilation = pow(2, (i % 3)) self.res_layers.append( ResBlock( inplanes=dim_in, @@ -735,7 +734,6 @@ def __init__( self.dropout = nn.Dropout(dropout) self._dropout_rate = dropout - (self.num_heads * self.head_size) ** -0.5 ### DOUBLE CHECK THIS CODE: self.query = nn.Linear(num_heads * head_size, num_heads * head_size, bias=False) self.key = nn.Linear(num_heads * head_size, num_heads * head_size, bias=False) @@ -846,7 +844,6 @@ def __init__(self, kernel_sizes=None, strides=None, **kwargs): super(RelPositionMultiHeadAttention, self).__init__(**kwargs) num_pos_features = self.num_heads * self.head_size - (self.num_heads * self.head_size) ** -0.5 self.pos_kernel = nn.Parameter( torch.rand(self.num_heads, num_pos_features, self.head_size) * 2 - 1 ) @@ -924,10 +921,6 @@ def __init__( self.scale = nn.Parameter(torch.ones(input_dim)) self.bias = nn.Parameter(torch.zeros(input_dim)) - input_dim**-0.5 - kernel_size**-0.5 - input_dim**-0.5 - self.pw_conv_1 = nn.Conv1d( in_channels=input_dim, out_channels=conv_expansion_rate * input_dim, diff --git a/efold/models/ribonanza.py b/efold/models/ribonanza.py index bed56fd..057a3b8 100644 --- a/efold/models/ribonanza.py +++ b/efold/models/ribonanza.py @@ -2,7 +2,7 @@ from torch import nn from torch.nn import init -from efold import settings +from efold.constants import config from efold.core import model global_gain = 0.1 @@ -166,11 +166,11 @@ def sequence_batch(batch): out.append( torch.concat( [ - torch.tensor([settings.START_TOKEN], dtype=torch.long).to(settings.device), + torch.tensor([config.tokens.start_token], dtype=torch.long).to(config.device), sequence[:length], - torch.tensor([settings.END_TOKEN], dtype=torch.long).to(settings.device), - torch.tensor([settings.PADDING_TOKEN] * (L - length), dtype=torch.long).to( - settings.device + torch.tensor([config.tokens.end_token], dtype=torch.long).to(config.device), + torch.tensor([config.tokens.padding_token] * (L - length), dtype=torch.long).to( + config.device ), ], ) @@ -181,7 +181,7 @@ def structure_batch(batch): structure = batch.get("structure") batch_size, L, _ = structure.shape embedded_matrix = torch.zeros((batch_size, L + 2, L + 2), dtype=torch.float32).to( - settings.device + config.device ) embedded_matrix[:, 1:-1, 1:-1] = structure return embedded_matrix diff --git a/efold/models/transformer.py b/efold/models/transformer.py index 1fbc0e9..13449fb 100644 --- a/efold/models/transformer.py +++ b/efold/models/transformer.py @@ -88,7 +88,7 @@ def __init__( ResLayer(n_blocks=4, dim_in=c_z // 4, dim_out=1, kernel_size=3, dropout=dropout), ) - def forward(self, batch: batch.Batch) -> Tensor: + def forward(self, batch: batch.Batch) -> dict[str, Tensor]: """ Args: src: Tensor, shape [seq_len, batch_size] diff --git a/efold/models/unet.py b/efold/models/unet.py index b398bd6..0e9c7b3 100644 --- a/efold/models/unet.py +++ b/efold/models/unet.py @@ -3,7 +3,7 @@ import torch from torch import Tensor, nn -from efold import settings +from efold.constants import config from efold.core 
import batch, model CH_FOLD2 = 1 @@ -39,7 +39,7 @@ def __init__( self.Conv_1x1 = nn.Conv2d(int(32 * CH_FOLD2), output_ch, kernel_size=1, stride=1, padding=0) - def forward(self, batch: batch.Batch) -> Tensor: + def forward(self, batch: batch.Batch) -> dict[str, Tensor]: src = batch.get("sequence") padd_multiple = 32 @@ -111,7 +111,7 @@ def seq2map(self, seq_int): # take integer encoded sequence and return last channel of embedding (pairing energy) def creatmat(data, device=None): with torch.no_grad(): - data = "".join([settings.int2seq[d] for d in data.tolist()]) + data = "".join([config.tokens.int2seq[d] for d in data.tolist()]) paired = defaultdict( float, {"AU": 2.0, "UA": 2.0, "GC": 3.0, "CG": 3.0, "UG": 0.8, "GU": 0.8} ) diff --git a/efold/settings.py b/efold/settings.py deleted file mode 100644 index 4e5208e..0000000 --- a/efold/settings.py +++ /dev/null @@ -1,56 +0,0 @@ -from pathlib import Path -from typing import Any - -import torch -import yaml -from torch import cuda, float32 - -_settings_path = Path(__file__).parent / "settings.yaml" - - -def _load_settings() -> dict[str, Any]: - """ - Load settings from the YAML configuration file. - - :return: Dictionary containing all settings - """ - with open(_settings_path, "r") as f: - return yaml.safe_load(f) - - -_config = _load_settings() - -seq2int = _config["seq2int"] -int2seq = {v: k for k, v in seq2int.items()} - -START_TOKEN = _config["start_token"] -END_TOKEN = _config["end_token"] -PADDING_TOKEN = seq2int[_config["padding_token_key"]] - -DEFAULT_FORMAT = float32 -torch.set_default_dtype(DEFAULT_FORMAT) -UKN = _config["unknown_value"] -VAL_GU = _config["val_gu"] - -device = "cuda" if cuda.is_available() else "cpu" - -TEST_SETS = _config["test_sets"] -TEST_SETS_NAMES = [i for j in TEST_SETS.values() for i in j] -DATA_TYPES_TEST_SETS = [k for k, v in TEST_SETS.items() for i in v] - -DATA_TYPES = _config["data_types"] - -_dtype_mapping = { - "float32": torch.float32, - "int32": torch.int32, -} - -DATA_TYPES_FORMAT = {k: _dtype_mapping[v] for k, v in _config["data_types_format"].items()} - -REFERENCE_METRIC = _config["reference_metric"] -REF_METRIC_SIGN = _config["ref_metric_sign"] -POSSIBLE_METRICS = _config["possible_metrics"] - -DTYPE_PER_DATA_TYPE = DATA_TYPES_FORMAT - -torch.set_default_dtype(torch.float32) diff --git a/efold/settings.yaml b/efold/settings.yaml deleted file mode 100644 index 07436e5..0000000 --- a/efold/settings.yaml +++ /dev/null @@ -1,66 +0,0 @@ -# Token mappings -seq2int: - X: 0 - A: 1 - C: 2 - G: 3 - U: 4 - -# Special tokens -start_token: null -end_token: null -padding_token_key: X - -# PyTorch configuration -default_format: float32 -unknown_value: -1000.0 -val_gu: 0.095 - -# Test sets -test_sets: - structure: - - PDB - - archiveII - - lncRNA_nonFiltered - - viral_fragments - sequence: [] - dms: [] - shape: [] - -# Data types -data_types: - - structure - - dms - - shape - -# Data type formats -data_types_format: - structure: int32 - dms: float32 - shape: float32 - -# Reference metrics -reference_metric: - structure: f1 - dms: mae - shape: mae - -# Metric signs (1 for higher is better, -1 for lower is better) -ref_metric_sign: - structure: 1 - dms: -1 - shape: -1 - -# Possible metrics per data type -possible_metrics: - structure: - - f1 - dms: - - mae - - r2 - - pearson - shape: - - mae - - r2 - - pearson - diff --git a/efold/util/format_conversion.py b/efold/util/format_conversion.py index 7882433..8bbb61c 100644 --- a/efold/util/format_conversion.py +++ b/efold/util/format_conversion.py @@ -1,7 +1,8 @@ # 
this code wsa taken from arnie_utils.py +from typing import Optional -def convert_bp_list_to_dotbracket(bp_list: list[tuple[int, int]], seq_len: int) -> str: +def convert_bp_list_to_dotbracket(bp_list: list[tuple[int, int]], seq_len: int) -> Optional[str]: bp_list = [(b - 1, c - 1) for b, c in bp_list] db = "." * seq_len diff --git a/scripts/cnn_template.py b/scripts/cnn_template.py index 1b6a4aa..f54abf6 100644 --- a/scripts/cnn_template.py +++ b/scripts/cnn_template.py @@ -12,7 +12,7 @@ from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.strategies import DDPStrategy -from efold import settings +from efold.constants import config from efold.core import callbacks, datamodule from efold.models import factory @@ -27,7 +27,7 @@ n_gpu = 8 USE_WANDB = 0 STRATEGY = "random" - print("Running on device: {}".format(settings.device)) + print("Running on device: {}".format(config.device)) if USE_WANDB: project = "Structure-classic" wandb_logger = WandbLogger(project=project) @@ -68,7 +68,7 @@ model.load_state_dict( torch.load( "/Users/alberic/Desktop/lively-waterfall-8_epoch45.pt", - map_location=torch.device(settings.device), + map_location=torch.device(config.device), ) ) @@ -76,7 +76,7 @@ wandb_logger.watch(model, log="all") trainer = Trainer( - accelerator=settings.device, + accelerator=config.device, devices=n_gpu if STRATEGY == "ddp" else 1, strategy=DDPStrategy(find_unused_parameters=False) if STRATEGY == "ddp" else "auto", # precision="16-mixed", diff --git a/scripts/efold_training.py b/scripts/efold_training.py index 3943db9..abe9d72 100644 --- a/scripts/efold_training.py +++ b/scripts/efold_training.py @@ -9,7 +9,7 @@ from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.strategies import DDPStrategy -from efold import settings +from efold.constants import config from efold.core import callbacks, datamodule from efold.models import factory @@ -19,7 +19,7 @@ STRATEGY = "random" n_gpu = 1 - print("Running on device: {}".format(settings.device)) + print("Running on device: {}".format(config.device)) if USE_WANDB: wandb_logger = WandbLogger(project="test") @@ -58,7 +58,7 @@ wandb_logger.watch(model, log="all") trainer = Trainer( - accelerator=settings.device, + accelerator=config.device, devices=n_gpu if STRATEGY == "ddp" else 1, strategy=DDPStrategy(find_unused_parameters=False) if STRATEGY == "ddp" else "auto", max_epochs=15, diff --git a/scripts/mlp-template.py b/scripts/mlp-template.py index 8544c6f..f652db5 100644 --- a/scripts/mlp-template.py +++ b/scripts/mlp-template.py @@ -5,7 +5,7 @@ from lightning.pytorch import Trainer from lightning.pytorch.loggers import WandbLogger -from efold import settings +from efold.constants import config from efold.core import callbacks, datamodule from efold.models import factory @@ -18,7 +18,7 @@ # why do you need this? 
if __name__ == "__main__": - print("Running on device: {}".format(settings.device)) + print("Running on device: {}".format(config.device)) BATCH_SIZE = 4 LR = 5e-5 @@ -71,7 +71,7 @@ max_epochs=1, log_every_n_steps=1, logger=wandb_logger, - accelerator=settings.device, + accelerator=config.device, callbacks=[ callbacks.ModelCheckpoint(every_n_epoch=1), ], diff --git a/scripts/ribonanza-template.py b/scripts/ribonanza-template.py index e08caa8..79e04f4 100644 --- a/scripts/ribonanza-template.py +++ b/scripts/ribonanza-template.py @@ -10,13 +10,13 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from efold import settings +from efold.constants import config from efold.core import callbacks, datamodule from efold.models import factory if __name__ == "__main__": USE_WANDB = True - print("Running on device: {}".format(settings.device)) + print("Running on device: {}".format(config.device)) if USE_WANDB: wandb_logger = WandbLogger(project="ribonanza-solution", name="first-run") @@ -62,7 +62,7 @@ devices=8, strategy=DDPStrategy(find_unused_parameters=True), max_epochs=1000, - accelerator=settings.device, + accelerator=config.device, logger=wandb_logger if USE_WANDB else None, callbacks=[ LearningRateMonitor(logging_interval="epoch"), diff --git a/scripts/transformer-template.py b/scripts/transformer-template.py index f4dca33..cc914ee 100644 --- a/scripts/transformer-template.py +++ b/scripts/transformer-template.py @@ -7,7 +7,7 @@ from lightning.pytorch.callbacks import LearningRateMonitor from lightning.pytorch.loggers import WandbLogger -from efold import settings +from efold.constants import config from efold.core import callbacks, datamodule from efold.models import factory @@ -18,7 +18,7 @@ if __name__ == "__main__": USE_WANDB = True - print("Running on device: {}".format(settings.device)) + print("Running on device: {}".format(config.device)) if USE_WANDB: wandb_logger = WandbLogger(project="CHANGE_ME", name="debug") @@ -65,7 +65,7 @@ # train with both splits trainer = Trainer( - accelerator=settings.device, + accelerator=config.device, # devices=4, # strategy="ddp", # precision="16-mixed", @@ -97,7 +97,7 @@ ) trainer = Trainer( - accelerator=settings.device, + accelerator=config.device, devices=1, callbacks=[ # don't change this diff --git a/scripts/unet_training.py b/scripts/unet_training.py index 4365743..b5a7733 100644 --- a/scripts/unet_training.py +++ b/scripts/unet_training.py @@ -9,7 +9,7 @@ from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.strategies import DDPStrategy -from efold import settings +from efold.constants import config from efold.core import callbacks, datamodule from efold.models import factory @@ -17,7 +17,7 @@ if __name__ == "__main__": USE_WANDB = False STRATEGY = "random" - print("Running on device: {}".format(settings.device)) + print("Running on device: {}".format(config.device)) if USE_WANDB: wandb_logger = WandbLogger(project="test") @@ -49,7 +49,7 @@ wandb_logger.watch(model, log="all") trainer = Trainer( - accelerator=settings.device, + accelerator=config.device, devices=8 if STRATEGY == "ddp" else 1, strategy=DDPStrategy(find_unused_parameters=False) if STRATEGY == "ddp" else "auto", max_epochs=15, diff --git a/setup.py b/setup.py index 7c9b84b..26fab4c 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,6 @@ python_requires=">=3.10", py_modules=["efold"], include_package_data=True, - package_data={"": ["resources/*.pt", "settings.yaml"]}, + package_data={"": ["resources/*.pt"]}, packages=find_packages(), )