diff --git a/README.md b/README.md index d58641b..7745013 100644 --- a/README.md +++ b/README.md @@ -50,8 +50,8 @@ efold -h ### Using python ```python ->>> from efold import inference ->>> inference('AAACAUGAGGAUUACCCAUGU', fmt='dotbracket') +>>> from efold.api import run +>>> run.run('AAACAUGAGGAUUACCCAUGU', fmt='dotbracket') ..(((((.((....))))))) ``` diff --git a/efold/__init__.py b/efold/__init__.py index d99bd05..e69de29 100644 --- a/efold/__init__.py +++ b/efold/__init__.py @@ -1,5 +0,0 @@ -from .models import create_model -from .core import * -from .util import * -from .config import * -from .api import * \ No newline at end of file diff --git a/efold/api/__init__.py b/efold/api/__init__.py index 35095cd..e69de29 100644 --- a/efold/api/__init__.py +++ b/efold/api/__init__.py @@ -1 +0,0 @@ -from .run import run as inference \ No newline at end of file diff --git a/efold/api/run.py b/efold/api/run.py index f21cddf..69ad475 100644 --- a/efold/api/run.py +++ b/efold/api/run.py @@ -1,19 +1,20 @@ import os -from typing import List, Union -from ..models import create_model -import torch -from os.path import join, dirname -from ..core import batch -from ..core.embeddings import sequence_to_int -from ..core.postprocess import Postprocess +from os.path import dirname, join +from typing import List, Optional, Union + import numpy as np -from ..util.format_conversion import convert_bp_list_to_dotbracket +import torch + +from efold.core import batch, embeddings, postprocess +from efold.models import factory +from efold.util import format_conversion torch.set_default_dtype(torch.float32) -postprocesser = Postprocess() +postprocesser = postprocess.Postprocess() -def _load_sequences_from_fasta(fasta:str): + +def _load_sequences_from_fasta(fasta: str) -> list[str]: with open(fasta, "r") as f: lines = f.readlines() sequences = [] @@ -24,36 +25,44 @@ def _load_sequences_from_fasta(fasta:str): sequences[-1] += line.strip() return sequences -def _predict_structure(model, sequence:str, device='cpu'): - seq = sequence_to_int(sequence).unsqueeze(0) +def _predict_structure(model, sequence: str, device: str = "cpu") -> list[tuple[int, int]]: + seq = embeddings.sequence_to_int(sequence).unsqueeze(0) b = batch.Batch( sequence=seq, reference=[""], length=[len(seq)], - L = len(seq), + L=len(seq), use_error=False, batch_size=1, data_types=["sequence"], - dt_count={"sequence": 1}).to(device) - + dt_count={"sequence": 1}, + ).to(device) + # predict the structure with torch.inference_mode(): pred = model(b) - structure = postprocesser.run(pred['structure'].to('cpu'), b.get('sequence').to('cpu')).numpy().round()[0] + structure = ( + postprocesser.run(pred["structure"].to("cpu"), b.get("sequence").to("cpu")) + .numpy() + .round()[0] + ) # turn into 1-indexed base pairs - return [(b,c) for b, c in (np.stack(np.where(np.triu(structure) == 1)) + 1).T] + return [(b, c) for b, c in (np.stack(np.where(np.triu(structure) == 1)) + 1).T] + -def run(arg:Union[str, List[str]]=None, fmt="dotbracket", device=None): +def run( + arg: Optional[Union[str, List[str]]] = None, fmt: str = "dotbracket", device: Optional[str] = None +) -> dict[str, Union[str, list[tuple[int, int]]]]: """Runs the Efold API on the provided sequence or fasta file. - + Args: - arg (str): The sequence or the list of sequences to run Efold on, or the path to a fasta file containing the sequences. - + arg (str): The sequence or the list of sequences to run Efold on, or the path to a fasta file containing the sequences. + Returns: dict: A dictionary containing the sequences as keys and the predicted secondary structures as values. - + Examples: >>> from efold.api.run import run >>> structure = run("GGGAAAUCC") # this is awful, we need to remove the prints @@ -66,9 +75,11 @@ def run(arg:Union[str, List[str]]=None, fmt="dotbracket", device=None): No scaling, use preLN Replace GLU with swish for Conv >>> assert structure == {'GGGAAAUCC': [(1, 9), (2, 8)]}, "Test failed: {}".format(structure) - + """ - assert fmt in ["dotbracket", "basepair", 'bp'], "Invalid format. Must be either 'dotbracket' or 'basepair'" + assert fmt in ["dotbracket", "basepair", "bp"], ( + "Invalid format. Must be either 'dotbracket' or 'basepair'" + ) # Check if the input is valid if not arg: @@ -77,13 +88,13 @@ def run(arg:Union[str, List[str]]=None, fmt="dotbracket", device=None): if not os.path.exists(arg): raise ValueError("File not found") sequences = _load_sequences_from_fasta(arg) - elif type(arg) == str: + elif isinstance(arg, str): sequences = [arg] elif hasattr(arg, "__iter__") and all([isinstance(s, str) for s in arg]): sequences = arg else: raise ValueError("Either sequence or fasta must be provided") - + # Get device if not device: if torch.cuda.is_available(): @@ -92,7 +103,7 @@ def run(arg:Union[str, List[str]]=None, fmt="dotbracket", device=None): device = torch.device("cpu") # Load best model - model = create_model( + model = factory.create_model( model="efold", ntoken=5, d_model=64, @@ -105,17 +116,19 @@ def run(arg:Union[str, List[str]]=None, fmt="dotbracket", device=None): weight_decay=0, gamma=0.995, ) - model.load_state_dict(torch.load(join(dirname(dirname(__file__)), "resources/efold_weights.pt")), strict=False) + model.load_state_dict( + torch.load(join(dirname(dirname(__file__)), "resources/efold_weights.pt")), strict=False + ) model.eval() model = model.to(device) structures = [] - for seq in sequences: + for seq in sequences: structure = _predict_structure(model, seq, device=device) if fmt == "dotbracket": - db_structure = convert_bp_list_to_dotbracket(structure, len(seq)) - if db_structure != None: + db_structure = format_conversion.convert_bp_list_to_dotbracket(structure, len(seq)) + if db_structure is not None: structure = db_structure structures.append(structure) - return {seq: structure for seq, structure in zip(sequences, structures)} \ No newline at end of file + return {seq: structure for seq, structure in zip(sequences, structures)} diff --git a/efold/cli.py b/efold/cli.py index afb737b..2c626f0 100644 --- a/efold/cli.py +++ b/efold/cli.py @@ -1,34 +1,43 @@ import json + import click -from efold.api.run import run - -@click.command('efold') -@click.argument('sequence', required=False, type=str) -@click.option('--fasta', '-f', help='Input FASTA file path') -@click.option('--output', '-o', default='output.txt', help='Output file path (json, txt or csv)', type=click.Path()) -@click.option('--basepair/--dotbracket', '-bp/-db', default=False, help='Output structure format') -@click.option('--help', '-h', is_flag=True, help='Show this message', type=bool) -def cli(sequence, fasta, output, basepair, help): - + +from efold.api import run + + +@click.command("efold") +@click.argument("sequence", required=False, type=str) +@click.option("--fasta", "-f", help="Input FASTA file path") +@click.option( + "--output", + "-o", + default="output.txt", + help="Output file path (json, txt or csv)", + type=click.Path(), +) +@click.option("--basepair/--dotbracket", "-bp/-db", default=False, help="Output structure format") +@click.option("--help", "-h", is_flag=True, help="Show this message", type=bool) +def cli(sequence: str, fasta: str, output: str, basepair: bool, help: bool) -> None: if help: click.echo(cli.get_help(click.Context(cli))) return - - fmt = 'bp' if basepair else 'dotbracket' + + fmt = "bp" if basepair else "dotbracket" if sequence: - result = run(sequence, fmt) + result = run.run(sequence, fmt) elif fasta: - result = run(fasta, fmt) + result = run.run(fasta, fmt) else: click.echo("Please provide either a sequence or a FASTA file.") return - with open(output, 'w') as f: - file_fmt = output.split('.')[-1] - if file_fmt == 'json': + with open(output, "w") as f: + file_fmt = output.split(".")[-1] + if file_fmt == "json": f.write(json.dumps(result, indent=4)) - elif file_fmt == 'csv': + elif file_fmt == "csv": import csv + writer = csv.writer(f) writer.writerows(result.items()) else: @@ -40,5 +49,6 @@ def cli(sequence, fasta, output, basepair, help): click.echo() click.echo(f"Output saved to {output}") -if __name__ == '__main__': - cli() \ No newline at end of file + +if __name__ == "__main__": + cli() diff --git a/efold/config.py b/efold/config.py deleted file mode 100644 index 3387524..0000000 --- a/efold/config.py +++ /dev/null @@ -1,55 +0,0 @@ -from torch import float32, cuda, backends -import torch - -seq2int = {"X": 0, "A": 1, "C": 2, "G": 3, "U": 4} # , 'S': 5, 'E': 6} -# seq2int = {"X": 0, "A": 1, "U": 2, "C": 3, "G": 4} -int2seq = {v: k for k, v in seq2int.items()} - -START_TOKEN = None # seq2int['S'] -END_TOKEN = None # seq2int['E'] -PADDING_TOKEN = seq2int["X"] - -DEFAULT_FORMAT = float32 -torch.set_default_dtype(DEFAULT_FORMAT) -UKN = -1000.0 -VAL_GU = 0.095 -device = ( - "cuda" - if cuda.is_available() - # else "mps" # moi j'aime bien le mps - # if backends.mps.is_available() - else "cpu" -) - -TEST_SETS = { - "structure": ["PDB", "archiveII", "lncRNA_nonFiltered", "viral_fragments"], - "sequence": [], - "dms": [], - "shape": [], -} - - -TEST_SETS_NAMES = [i for j in TEST_SETS.values() for i in j] -DATA_TYPES_TEST_SETS = [k for k, v in TEST_SETS.items() for i in v] - -DATA_TYPES = ["structure", "dms", "shape"] -DATA_TYPES_FORMAT = { - "structure": torch.int32, - "dms": DEFAULT_FORMAT, - "shape": DEFAULT_FORMAT, -} -REFERENCE_METRIC = {"structure": "f1", "dms": "mae", "shape": "mae"} -REF_METRIC_SIGN = {"structure": 1, "dms": -1, "shape": -1} -POSSIBLE_METRICS = { - "structure": ["f1"], # , "mFMI"], - "dms": ["mae", "r2", "pearson"], - "shape": ["mae", "r2", "pearson"], -} - -DTYPE_PER_DATA_TYPE = { - "structure": torch.int32, - "dms": DEFAULT_FORMAT, - "shape": DEFAULT_FORMAT, -} - -torch.set_default_dtype(torch.float32) diff --git a/efold/constants.py b/efold/constants.py new file mode 100644 index 0000000..e0a6d5c --- /dev/null +++ b/efold/constants.py @@ -0,0 +1,199 @@ +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +import torch +import yaml +from torch import cuda + + +@dataclass +class TokenMapping: + """Token mappings for nucleotides. + + :param seq2int: Mapping from sequence characters to integers + :param start_token: Optional start token + :param end_token: Optional end token + :param padding_token_key: Key used for padding token + """ + seq2int: dict[str, int] + start_token: Optional[str] + end_token: Optional[str] + padding_token_key: str + + @property + def int2seq(self) -> dict[int, str]: + """Reverse mapping from integers to sequences. + + :return: Dictionary mapping integers to sequence characters + """ + return {v: k for k, v in self.seq2int.items()} + + @property + def padding_token(self) -> int: + """Get the padding token integer value. + + :return: Integer value of padding token + """ + return self.seq2int[self.padding_token_key] + + +@dataclass +class PyTorchConfig: + """PyTorch configuration settings. + + :param unknown_value: Value used for unknown/missing data + :param val_gu: Value for GU base pairing + """ + unknown_value: float + val_gu: float + + +@dataclass +class TestSets: + """Test set definitions. + + :param structure: List of structure test sets + :param sequence: List of sequence test sets + :param dms: List of DMS test sets + :param shape: List of SHAPE test sets + """ + structure: list[str] + sequence: list[str] + dms: list[str] + shape: list[str] + + @property + def as_dict(self) -> dict[str, list[str]]: + """Return test sets as dictionary. + + :return: Dictionary of test set types to names + """ + return { + "structure": self.structure, + "sequence": self.sequence, + "dms": self.dms, + "shape": self.shape + } + + @property + def all_names(self) -> list[str]: + """Get all test set names flattened. + + :return: List of all test set names + """ + return [name for sets in self.as_dict.values() for name in sets] + + @property + def data_types_per_test_set(self) -> list[str]: + """Get data type for each test set. + + :return: List of data types matching test sets + """ + return [dtype for dtype, names in self.as_dict.items() for _ in names] + + +@dataclass +class DataConfig: + """Data types and format configuration. + + :param types: List of available data types + :param types_format: Mapping of data types to their torch dtype + """ + types: list[str] + types_format: dict[str, str] + + @property + def types_format_torch(self) -> dict[str, torch.dtype]: + """Get data types as torch dtypes. + + :return: Dictionary mapping data types to torch.dtype objects + """ + dtype_map = { + "int32": torch.int32, + "float32": torch.float32, + "float64": torch.float64, + } + return {k: dtype_map[v] for k, v in self.types_format.items()} + + +@dataclass +class MetricsConfig: + """Metrics configuration. + + :param reference_metric: Primary metric for each data type + :param ref_metric_sign: Sign convention for metrics (1=higher is better, -1=lower is better) + :param possible_metrics: List of available metrics per data type + """ + reference_metric: dict[str, str] + ref_metric_sign: dict[str, int] + possible_metrics: dict[str, list[str]] + + +@dataclass +class Config: + """Main configuration class combining all config sections. + + :param tokens: Token mapping configuration + :param pytorch: PyTorch-specific configuration + :param test_sets: Test set definitions + :param data: Data type configuration + :param metrics: Metrics configuration + """ + tokens: TokenMapping + pytorch: PyTorchConfig + test_sets: TestSets + data: DataConfig + metrics: MetricsConfig + device: str = field(init=False) + + def __post_init__(self) -> None: + """Initialize device after other fields are set. + + :return: None + """ + object.__setattr__(self, 'device', "cuda" if cuda.is_available() else "cpu") + + +def _load_config() -> Config: + """Load configuration from YAML file. + + :return: Instantiated Config object + """ + config_path = Path(__file__).parent / "constants.yaml" + + with open(config_path, 'r') as f: + data = yaml.safe_load(f) + + return Config( + tokens=TokenMapping( + seq2int=data['tokens']['seq2int'], + start_token=data['tokens']['start_token'], + end_token=data['tokens']['end_token'], + padding_token_key=data['tokens']['padding_token_key'] + ), + pytorch=PyTorchConfig( + unknown_value=data['pytorch']['unknown_value'], + val_gu=data['pytorch']['val_gu'] + ), + test_sets=TestSets( + structure=data['test_sets']['structure'], + sequence=data['test_sets']['sequence'], + dms=data['test_sets']['dms'], + shape=data['test_sets']['shape'] + ), + data=DataConfig( + types=data['data']['types'], + types_format=data['data']['types_format'] + ), + metrics=MetricsConfig( + reference_metric=data['metrics']['reference_metric'], + ref_metric_sign=data['metrics']['ref_metric_sign'], + possible_metrics=data['metrics']['possible_metrics'] + ) + ) + + +config = _load_config() + +torch.set_default_dtype(torch.float32) diff --git a/efold/constants.yaml b/efold/constants.yaml new file mode 100644 index 0000000..f098739 --- /dev/null +++ b/efold/constants.yaml @@ -0,0 +1,62 @@ +# RNA Secondary Structure Prediction Configuration + +tokens: + seq2int: + X: 0 + A: 1 + C: 2 + G: 3 + U: 4 + + start_token: null + end_token: null + padding_token_key: X + +pytorch: + unknown_value: -1000.0 + val_gu: 0.095 + +test_sets: + structure: + - PDB + - archiveII + - lncRNA_nonFiltered + - viral_fragments + sequence: [] + dms: [] + shape: [] + +data: + types: + - structure + - dms + - shape + + types_format: + structure: int32 + dms: float32 + shape: float32 + +metrics: + reference_metric: + structure: f1 + dms: mae + shape: mae + + ref_metric_sign: + structure: 1 + dms: -1 + shape: -1 + + possible_metrics: + structure: + - f1 + dms: + - mae + - r2 + - pearson + shape: + - mae + - r2 + - pearson + diff --git a/efold/core/__init__.py b/efold/core/__init__.py index 0f07848..e69de29 100644 --- a/efold/core/__init__.py +++ b/efold/core/__init__.py @@ -1,4 +0,0 @@ -from .datamodule import DataModule -from . import metrics -from .dataset import Dataset -from .postprocess import Postprocess diff --git a/efold/core/batch.py b/efold/core/batch.py index 6ae0614..f84cf84 100644 --- a/efold/core/batch.py +++ b/efold/core/batch.py @@ -1,33 +1,31 @@ +from typing import Optional + import torch -from torch import tensor import torch.nn.functional as F -from .embeddings import base_pairs_to_pairing_matrix, sequence_to_int -from ..config import device, POSSIBLE_METRICS, UKN -from typing import Dict -from .datatype import data_type_factory -from .util import split_data_type -from torch import cuda, backends + +from efold.constants import config +from efold.core import datatype, embeddings, util -def _pad(arr, L, data_type, accept_none=False): +def _pad(arr: torch.Tensor, L: int, data_type: str, accept_none: bool = False) -> torch.Tensor: padding_values = { "sequence": 0, - "dms": UKN, - "shape": UKN, + "dms": config.pytorch.unknown_value, + "shape": config.pytorch.unknown_value, } - assert ( - data_type in padding_values.keys() - ), f"Unknown data type {data_type}. If you want to pad a structure, use base_pairs_to_pairing_matrix." + assert data_type in padding_values.keys(), ( + f"Unknown data type {data_type}. If you want to pad a structure, use base_pairs_to_pairing_matrix." + ) if accept_none and arr is None: - return tensor([padding_values[data_type]] * L) + return torch.tensor([padding_values[data_type]] * L) return F.pad(arr, (0, L - len(arr)), value=padding_values[data_type]) -def get_padded_vector(dp, data_type, data_part, L): +def get_padded_vector(dp: dict, data_type: str, data_part: str, L: int) -> torch.Tensor: if getattr(dp, data_type) is None: - return tensor([UKN] * L) + return torch.tensor([config.pytorch.unknown_value] * L) if getattr(getattr(dp, data_type), data_part) is None: - return tensor([UKN] * L) + return torch.tensor([config.pytorch.unknown_value] * L) return _pad(getattr(getattr(dp, data_type), data_part), L, data_type) @@ -45,7 +43,7 @@ def __init__( dms=None, shape=None, structure=None, - device = 'cpu' + device="cpu", ): self.reference = reference self.sequence = sequence @@ -66,7 +64,7 @@ def from_dataset_items( batch_data: list, data_type: str, use_error: bool, - structure_padding_value: float = UKN, + structure_padding_value: float = config.pytorch.unknown_value, ): reference = [dp["reference"] for dp in batch_data] length = [dp["length"] for dp in batch_data] @@ -74,7 +72,7 @@ def from_dataset_items( # move the conversion to the dataset sequence = torch.stack( - [_pad(sequence_to_int(dp["sequence"]), L, "sequence") for dp in batch_data] + [_pad(embeddings.sequence_to_int(dp["sequence"]), L, "sequence") for dp in batch_data] ) batch_size = len(reference) @@ -91,16 +89,16 @@ def from_dataset_items( } for dt in data_type: if dt == "structure": - data[dt] = data_type_factory["batch"][dt]( + data[dt] = datatype.data_type_factory["batch"][dt]( true=torch.stack( [ - base_pairs_to_pairing_matrix( + embeddings.base_pairs_to_pairing_matrix( dp["structure"]["true"], - l, + len_, padding=L, pad_value=structure_padding_value, ) - for (dp, l) in zip(batch_data, length) + for (dp, len_) in zip(batch_data, length) ] ), error=None, @@ -113,16 +111,14 @@ def from_dataset_items( true = torch.stack(true) # use error if there's a single non-None error and if the true signal is not None - if use_error and len( - [1 for dp in batch_data if dp[dt]["error"] is not None] - ): + if use_error and len([1 for dp in batch_data if dp[dt]["error"] is not None]): for dp in batch_data: error.append(_pad(dp[dt]["error"], L, dt, accept_none=True)) error = torch.stack(error) else: error = [None] * batch_size - data[dt] = data_type_factory["batch"][dt](true=true, error=error) + data[dt] = datatype.data_type_factory["batch"][dt](true=true, error=error) return cls( reference=reference, @@ -136,12 +132,12 @@ def from_dataset_items( **data, ) - def get(self, data_type, index=None, to_numpy=False): + def get(self, data_type: str, index: Optional[int] = None, to_numpy: bool = False): if data_type in ["reference", "sequence", "length"]: out = getattr(self, data_type) data_part = None else: - data_part, data_type = split_data_type(data_type) + data_part, data_type = util.split_data_type(data_type) # could be in the dataset but wasn't requested in the dm init if data_type not in self.data_types: @@ -157,9 +153,9 @@ def get(self, data_type, index=None, to_numpy=False): if index is not None: out = out[index] if hasattr(out, "__len__"): - l = self.get("length")[index] + len_ = self.get("length")[index] if data_type == "structure": - out = out[:l, :l] + out = out[:len_, :len_] else: out = out[: self.get("length")[index]] @@ -168,7 +164,7 @@ def get(self, data_type, index=None, to_numpy=False): out = out.squeeze().cpu().numpy() return out - def integrate_prediction(self, prediction): + def integrate_prediction(self, prediction: dict) -> None: for data_type, pred in prediction.items(): if getattr(self, data_type) is not None: getattr(self, data_type).pred = pred @@ -176,44 +172,39 @@ def integrate_prediction(self, prediction): setattr( self, data_type, - data_type_factory["batch"][data_type](true=None, pred=pred), + datatype.data_type_factory["batch"][data_type](true=None, pred=pred), ) - def get_pairs(self, data_type, to_numpy=False): + def get_pairs(self, data_type: str, to_numpy: bool = False) -> tuple: return ( self.get("pred_{}".format(data_type), to_numpy=to_numpy), self.get("true_{}".format(data_type), to_numpy=to_numpy), ) - def count(self, data_type): + def count(self, data_type: str) -> int: if data_type in ["reference", "sequence", "length"]: return self.batch_size - if not data_type in self.dt_count or getattr(self, data_type) is None: + if data_type not in self.dt_count or getattr(self, data_type) is None: return 0 return self.dt_count[data_type] - def contains(self, data_type): + def contains(self, data_type: str) -> bool: if data_type in ["reference", "sequence", "length"]: return True - data_part, data_type = split_data_type(data_type) + data_part, data_type = util.split_data_type(data_type) if not self.count(data_type): return False if ( - not hasattr( - getattr(self, data_type), data_part - ) # that's more of a sanity check + not hasattr(getattr(self, data_type), data_part) # that's more of a sanity check or getattr(getattr(self, data_type), data_part) is None ): return False return True - def __len__(self): + def __len__(self) -> int: return self.count("sequence") - - # return out - - def __del__(self): + def __del__(self) -> None: del self.dms del self.shape del self.structure @@ -225,27 +216,26 @@ def __del__(self): del self.data_types del self.dt_count del self - + @property - def device(self): + def device(self) -> str: return self._device @device.getter - def device(self): + def device(self) -> str: return self._device @device.setter - def device(self, device): - # assert device exists - if device == 'mps' and not backends.mps.is_available(): + def device(self, device: str) -> None: + if device == "mps" and not torch.backends.mps.is_available(): raise ValueError("MPS is not available on this device.") - if device == 'cuda' and not cuda.is_available(): + if device == "cuda" and not torch.cuda.is_available(): raise ValueError("CUDA is not available on this device.") - for attr in ['dms', 'shape', 'structure', 'sequence']: + for attr in ["dms", "shape", "structure", "sequence"]: if getattr(self, attr) is not None: setattr(self, attr, getattr(self, attr).to(device)) self._device = device - def to(self, device): + def to(self, device: str) -> "Batch": self.device = device return self diff --git a/efold/core/callbacks.py b/efold/core/callbacks.py index 5593d0a..929c12d 100644 --- a/efold/core/callbacks.py +++ b/efold/core/callbacks.py @@ -1,27 +1,9 @@ -import os -from lightning import LightningModule, Trainer import lightning.pytorch as pl -import torch -import numpy as np -import pandas as pd -from lightning.pytorch import LightningModule, Trainer -from lightning.pytorch.utilities import rank_zero_only import wandb -from typing import Any +from lightning.pytorch import Trainer +from lightning.pytorch.utilities import rank_zero_only -from .visualisation import plot_factory -from .metrics import metric_factory -from .datamodule import DataModule -from .loader import Loader -from .batch import Batch -from ..config import ( - TEST_SETS_NAMES, - REF_METRIC_SIGN, - REFERENCE_METRIC, - DATA_TYPES_TEST_SETS, - POSSIBLE_METRICS, -) -from .logger import Logger, LocalLogger +from efold.core import loader class ModelCheckpoint(pl.Callback): @@ -42,6 +24,6 @@ def on_validation_end(self, trainer: Trainer, pl_module, dataloader_idx=0): return name = "{}_epoch{}.pt".format(wandb.run.name, trainer.current_epoch) - loader = Loader(path="models/" + name) + loader_obj = loader.Loader(path="models/" + name) # logs what MAE it corresponds to - loader.dump(pl_module) + loader_obj.dump(pl_module) diff --git a/efold/core/dataloader.py b/efold/core/dataloader.py index 705c664..4ac7be9 100644 --- a/efold/core/dataloader.py +++ b/efold/core/dataloader.py @@ -1,6 +1,8 @@ -from torch.utils.data import DataLoader as _DataLoader import torch -from .batch import Batch +from torch.utils.data import DataLoader as _DataLoader + +from efold.core import batch + class DataLoader(_DataLoader): def __init__(self, *args, **kwargs): @@ -8,7 +10,9 @@ def __init__(self, *args, **kwargs): self.to_device = kwargs.pop("to_device") super().__init__(*args, **kwargs) - def transfer_batch_to_device(self, batch: Batch, device: torch.device, dataloader_idx: int) -> Batch: + def transfer_batch_to_device( + self, batch: batch.Batch, device: torch.device, dataloader_idx: int + ) -> batch.Batch: if self.to_device: return batch.to(device) return batch diff --git a/efold/core/datamodule.py b/efold/core/datamodule.py index 1f7862c..7eac910 100644 --- a/efold/core/datamodule.py +++ b/efold/core/datamodule.py @@ -1,12 +1,11 @@ -from torch.utils.data import random_split, Subset -import lightning.pytorch as pl -from typing import Union, List -from .dataset import Dataset -from ..config import TEST_SETS, UKN -from .sampler import sampler_factory -from .dataloader import DataLoader -import numpy as np import datetime +from typing import List, Optional, Union + +import lightning.pytorch as pl +from torch.utils.data import Subset + +from efold.constants import config +from efold.core import dataloader, dataset, sampler class DataModule(pl.LightningDataModule): @@ -26,7 +25,7 @@ def __init__( use_error=False, max_len=None, min_len=None, - structure_padding_value=UKN, + structure_padding_value=config.pytorch.unknown_value, tqdm=True, buckets=None, **kwargs, @@ -65,9 +64,9 @@ def __init__( "predict": predict_split, } if strategy in ["ddp", "sorted"]: - assert ( - shuffle_valid == shuffle_train == False - ), "You can't shuffle in ddp or sorted mode. Set shuffle_train and shuffle_valid to 0 or use strategy='random'." + assert shuffle_valid == shuffle_train is False, ( + "You can't shuffle in ddp or sorted mode. Set shuffle_train and shuffle_valid to 0 or use strategy='random'." + ) self.shuffle = { "train": shuffle_train, "valid": shuffle_valid, @@ -79,7 +78,6 @@ def __init__( "use_error": use_error, "force_download": force_download, "tqdm": tqdm, - "max_len": max_len, "min_len": min_len, } self.buckets = buckets @@ -97,18 +95,16 @@ def _use_multiple_datasets(self, name): def _dataset_merge(self, datasets): merge = datasets[0] collate_fn = merge.collate_fn - for dataset in datasets[1:]: - merge = merge + dataset + for ds in datasets[1:]: + merge = merge + ds merge.collate_fn = collate_fn return merge - def setup(self, stage: str = None): - if stage is None or ( - stage in ["fit", "predict"] and not hasattr(self, "all_datasets") - ): + def setup(self, stage: Optional[str] = None): + if stage is None or (stage in ["fit", "predict"] and not hasattr(self, "all_datasets")): self.all_datasets = self._dataset_merge( [ - Dataset.from_local_or_download( + dataset.Dataset.from_local_or_download( name=name, data_type=self.data_type, sort_by_length=self.strategy == "sorted", @@ -120,7 +116,7 @@ def setup(self, stage: str = None): self.collate_fn = self.all_datasets.collate_fn if stage == "fit": - if self.splits["train"] == None or self.splits["train"] == 1.0: + if self.splits["train"] is None or self.splits["train"] == 1.0: self.train_set = self.all_datasets else: num_datapoints = ( @@ -129,9 +125,9 @@ def setup(self, stage: str = None): else self.splits["train"] ) assert num_datapoints > 0, "train_split must be greater than 0" - assert num_datapoints <= len( - self.all_datasets - ), "train_split must be less than the number of datapoints" + assert num_datapoints <= len(self.all_datasets), ( + "train_split must be less than the number of datapoints" + ) self.train_set = Subset( self.all_datasets, range(0, num_datapoints), @@ -140,7 +136,7 @@ def setup(self, stage: str = None): self.external_val_set = [] for name in self.external_valid: self.external_val_set.append( - Dataset.from_local_or_download( + dataset.Dataset.from_local_or_download( name=name, data_type=self.data_type, sort_by_length=True, @@ -164,12 +160,13 @@ def setup(self, stage: str = None): def _select_test_dataset(self): return [ - Dataset.from_local_or_download( + dataset.Dataset.from_local_or_download( name=name, data_type=[data_type], **self.dataset_args, ) - for data_type, datasets in TEST_SETS.items() if data_type in self.data_type + for data_type, datasets in config.test_sets.as_dict.items() + if data_type in self.data_type for name in datasets ] @@ -179,21 +176,21 @@ def train_dataloader(self): raise ValueError( "When using strategy='ddp', the trainer must be passed to the datamodule" ) - else: # ddp + else: # ddp num_replicas = self.trainer.num_devices rank = self.trainer.local_rank else: num_replicas = 1 rank = 0 - return DataLoader( + return dataloader.DataLoader( self.train_set, shuffle=self.shuffle["train"], num_workers=self.num_workers, collate_fn=self.collate_fn, batch_size=self.batch_size, to_device=self.strategy != "ddp", - sampler=sampler_factory( + sampler=sampler.sampler_factory( dataset=self.train_set, strategy=self.strategy, num_replicas=num_replicas, @@ -210,13 +207,13 @@ def val_dataloader(self): if self.external_valid is not None: for val_set in self.external_val_set: val_dls.append( - DataLoader( + dataloader.DataLoader( val_set, shuffle=self.shuffle["valid"], collate_fn=self.collate_fn, batch_size=self.batch_size, to_device=self.strategy != "ddp", - sampler=sampler_factory( + sampler=sampler.sampler_factory( dataset=val_set, strategy=self.strategy, num_replicas=self.trainer.num_devices, @@ -229,7 +226,7 @@ def val_dataloader(self): def test_dataloader(self): return [ - DataLoader( + dataloader.DataLoader( test_set, num_workers=self.num_workers, collate_fn=test_set.collate_fn, @@ -239,7 +236,7 @@ def test_dataloader(self): ] def predict_dataloader(self): - return DataLoader( + return dataloader.DataLoader( self.predict_set, num_workers=self.num_workers, collate_fn=self.collate_fn, diff --git a/efold/core/dataset.py b/efold/core/dataset.py index f52b815..d910588 100644 --- a/efold/core/dataset.py +++ b/efold/core/dataset.py @@ -1,16 +1,12 @@ import os -import numpy as np -import torch -from torch.utils.data import ConcatDataset, Dataset as TorchDataset, Dataset -from typing import List +from typing import List, Optional -from .batch import Batch +import numpy as np from rouskinhf import get_dataset -from .datatype import DMSDataset, SHAPEDataset, StructureDataset -from .embeddings import sequence_to_int -from .util import _pad -from .path import Path -from ..config import UKN +from torch.utils.data import Dataset as TorchDataset + +from efold.constants import config +from efold.core import batch, datatype, path class Dataset(TorchDataset): @@ -25,9 +21,9 @@ def __init__( min_len: int, structure_padding_value: float, use_error: bool, - dms: DMSDataset = None, - shape: SHAPEDataset = None, - structure: StructureDataset = None, + dms: Optional[datatype.DMSDataset] = None, + shape: Optional[datatype.SHAPEDataset] = None, + structure: Optional[datatype.StructureDataset] = None, sort_by_length: bool = False, ) -> None: super().__init__() @@ -43,18 +39,18 @@ def __init__( self.structure_padding_value = structure_padding_value self.L = max(self.length) self._remove_sequences_out_of_length_interval(min_len, max_len) - + if sort_by_length: self.sort() def _remove_sequences_out_of_length_interval(self, min_len, max_len): if max_len is None: - max_len = np.inf + max_len = np.inf if min_len is None: min_len = 0 if min_len > max_len: raise ValueError("min_len must be smaller than max_len") - idx_out = [i for i, l in enumerate(self.length) if l >= max_len or l <= min_len] + idx_out = [i for i, len_ in enumerate(self.length) if len_ >= max_len or len_ <= min_len] for idx in idx_out[::-1]: del self.refs[idx] del self.length[idx] @@ -81,9 +77,7 @@ def __add__(self, other: "Dataset") -> "Dataset": refs=np.concatenate([self.refs, other.refs]).tolist(), length=np.concatenate([self.length, other.length]).tolist(), sequence=np.concatenate([self.sequence, other.sequence]).tolist(), - dms=self.dms + other.dms - if self.dms is not None and other.dms is not None - else None, + dms=self.dms + other.dms if self.dms is not None and other.dms is not None else None, shape=self.shape + other.shape if self.shape is not None and other.shape is not None else None, @@ -101,36 +95,36 @@ def from_local_or_download( use_error: bool = False, max_len=None, min_len=None, - structure_padding_value: float = UKN, + structure_padding_value: float = config.pytorch.unknown_value, sort_by_length: bool = False, tqdm=True, ): - path = Path(name=name) + path_obj = path.Path(name=name) if force_download: - path.clear() + path_obj.clear() - if os.path.exists(path.get_reference()): + if os.path.exists(path_obj.get_reference()): print("Loading dataset from disk") print("Load references \r", end="") - refs = path.load_reference().tolist() + refs = path_obj.load_reference().tolist() print("Load lengths \r", end="") - length = path.load_length().tolist() + length = path_obj.load_length().tolist() L = max(length) print("Load sequences \r", end="") - sequence = path.load_sequence().tolist() + sequence = path_obj.load_sequence().tolist() dms, shape, structure = None, None, None if "dms" in data_type: print("Load dms \r", end="") - dms = path.load_dms() + dms = path_obj.load_dms() if "shape" in data_type: print("Load shape \r", end="") - shape = path.load_shape() + shape = path_obj.load_shape() if "structure" in data_type: print("Load structure \r", end="") - structure = path.load_structure() + structure = path_obj.load_structure() else: data = get_dataset( @@ -142,28 +136,28 @@ def from_local_or_download( print("Dump lengths \r", end="") length = [len(d["sequence"]) for d in data.values()] - path.dump_length(np.array(length)) + path_obj.dump_length(np.array(length)) print("Dump references \r", end="") refs = list(data.keys()) - path.dump_reference(np.array(list(data.keys()))) + path_obj.dump_reference(np.array(list(data.keys()))) L = max(length) print("Dump sequences \r", end="") sequence = [d["sequence"] for d in data.values()] - path.dump_sequence(np.array(sequence)) + path_obj.dump_sequence(np.array(sequence)) print("Dump dms \r", end="") - dms = DMSDataset.from_data_json(data, L, refs) - path.dump_dms(dms) + dms = datatype.DMSDataset.from_data_json(data, L, refs) + path_obj.dump_dms(dms) print("Dump shape \r", end="") - shape = SHAPEDataset.from_data_json(data, L, refs) - path.dump_shape(shape) + shape = datatype.SHAPEDataset.from_data_json(data, L, refs) + path_obj.dump_shape(shape) print("Dump structure \r", end="") - structure = StructureDataset.from_data_json(data, L, refs) - path.dump_structure(structure) + structure = datatype.StructureDataset.from_data_json(data, L, refs) + path_obj.dump_structure(structure) print("Done! ") @@ -205,16 +199,14 @@ def __getitem__(self, index) -> tuple: "length": self.length[index], } for attr in ["dms", "shape", "structure"]: - out[attr] = ( - getattr(self, attr)[index] if getattr(self, attr) != None else None - ) + out[attr] = getattr(self, attr)[index] if getattr(self, attr) is not None else None return out def collate_fn(self, batch_data): - batch = Batch.from_dataset_items( + batch_obj = batch.Batch.from_dataset_items( batch_data, self.data_type, use_error=self.use_error, structure_padding_value=self.structure_padding_value, ) - return batch + return batch_obj diff --git a/efold/core/datatype.py b/efold/core/datatype.py index 1bcee59..2213f94 100644 --- a/efold/core/datatype.py +++ b/efold/core/datatype.py @@ -1,13 +1,14 @@ +from typing import Optional + import torch -from ..config import device, UKN, DTYPE_PER_DATA_TYPE -import torch.nn.functional as F -from .util import _pad + +from efold.constants import config class DataType: attributes = ["true", "pred", "error"] - def __init__(self, true: list, error: list = None, pred: list = None): + def __init__(self, true: list, error: Optional[list] = None, pred: Optional[list] = None): self.true = true self.error = error self.pred = pred @@ -17,7 +18,7 @@ def to(self, device): if hasattr(getattr(self, attr), "to"): setattr(self, attr, getattr(self, attr).to(device)) return self - + def __del__(self): del self.true del self.error @@ -46,9 +47,7 @@ def __getitem__(self, idx): def __add__(self, other): if self.name != other.name: - raise ValueError( - f"Cannot concatenate {self.name} and {other.name} datasets." - ) + raise ValueError(f"Cannot concatenate {self.name} and {other.name} datasets.") if other is None: return self @@ -69,7 +68,7 @@ def __delitem__(self, idx): del self.error[idx] if self.pred is not None: del self.pred[idx] - + def sort(self, idx_sorted): self.true = [self.true[i] for i in idx_sorted] if self.error is not None: @@ -85,16 +84,14 @@ def from_data_json(cls, data_json: dict, L: int, refs: list): values = data_json[ref] if data_type in values: true.append( - torch.tensor( - values[data_type], dtype=DTYPE_PER_DATA_TYPE[data_type] - ) + torch.tensor(values[data_type], dtype=config.data.types_format_torch[data_type]) ) if data_type != "structure": if "error_{}".format(data_type) in values: error.append( torch.tensor( values["error_{}".format(data_type)], - dtype=DTYPE_PER_DATA_TYPE[data_type], + dtype=config.data.types_format_torch[data_type], ) ) else: diff --git a/efold/core/embeddings.py b/efold/core/embeddings.py index 189578a..dcff337 100644 --- a/efold/core/embeddings.py +++ b/efold/core/embeddings.py @@ -1,24 +1,29 @@ -from torch import nn import torch -from ..config import DEFAULT_FORMAT, UKN, seq2int, int2seq +from torch import nn -NUM_BASES = len(set(seq2int.values())) +from efold.constants import config +NUM_BASES = len(set(config.tokens.seq2int.values())) -def sequence_to_int(sequence: str): - return torch.tensor([seq2int[s] for s in sequence], dtype=torch.int64) +def sequence_to_int(sequence: str) -> torch.Tensor: + return torch.tensor([config.tokens.seq2int[s] for s in sequence], dtype=torch.int64) -def int_to_sequence(sequence: torch.tensor): - return "".join([int2seq[i.item()] for i in sequence]) +def int_to_sequence(sequence: torch.Tensor) -> str: + return "".join([config.tokens.int2seq[i.item()] for i in sequence]) -def sequence_to_one_hot(sequence_batch: torch.tensor): - """Converts a sequence to a one-hot encoding""" - return nn.functional.one_hot(sequence_batch, NUM_BASES).type(DEFAULT_FORMAT) +def sequence_to_one_hot(sequence_batch: torch.Tensor) -> torch.Tensor: + return nn.functional.one_hot(sequence_batch, NUM_BASES).type(torch.float32) -def base_pairs_to_pairing_matrix(base_pairs, sequence_length, padding, pad_value=UKN): + +def base_pairs_to_pairing_matrix( + base_pairs: torch.Tensor, + sequence_length: int, + padding: int, + pad_value: float = config.pytorch.unknown_value, +) -> torch.Tensor: pairing_matrix = torch.ones((padding, padding)) * pad_value if base_pairs is None: return pairing_matrix @@ -28,4 +33,3 @@ def base_pairs_to_pairing_matrix(base_pairs, sequence_length, padding, pad_value pairing_matrix[base_pairs[:, 0], base_pairs[:, 1]] = 1.0 pairing_matrix[base_pairs[:, 1], base_pairs[:, 0]] = 1.0 return pairing_matrix - diff --git a/efold/core/loader.py b/efold/core/loader.py index 8c39946..6e3548e 100644 --- a/efold/core/loader.py +++ b/efold/core/loader.py @@ -1,7 +1,9 @@ +import os +from os import listdir, makedirs from os.path import dirname -from os import makedirs, listdir +from typing import Optional + import torch -import os class Loader: @@ -13,31 +15,29 @@ def __init__( makedirs(dirname(self.get_path()), exist_ok=True) @classmethod - def find_best_model(cls, prefix): + def find_best_model(cls, prefix: str) -> Optional["Loader"]: models = [model for model in listdir("models") if model.startswith(prefix)] if len(models) == 0: return None - models.sort( - key=lambda x: float(x.split("_mae:")[-1].split(".")[0].replace("-", ".")) - ) + models.sort(key=lambda x: float(x.split("_mae:")[-1].split(".")[0].replace("-", "."))) return cls(path="models/" + models[0]) - def get_path(self): + def get_path(self) -> str: return self.path - def get_name(self): + def get_name(self) -> str: return self.path.split("/")[-1].split(".")[0] - def write_in_log(self, epoch, mae): + def write_in_log(self, epoch: int, mae: float) -> "Loader": with open("models/_log.txt", "a") as f: f.write(f"{epoch} {self.get_name()}\t{mae}\n") return self - def load_from_weights(self, safe_load=True): + def load_from_weights(self, safe_load: bool = True) -> dict: if safe_load and os.path.exists(self.get_path()) or not safe_load: return torch.load(self.get_path()) raise FileNotFoundError(f"File {self.get_path()} not found") - def dump(self, model): + def dump(self, model) -> "Loader": torch.save(model.state_dict(), self.get_path()) return self diff --git a/efold/core/logger.py b/efold/core/logger.py index 6c248a4..8ae4991 100644 --- a/efold/core/logger.py +++ b/efold/core/logger.py @@ -1,9 +1,8 @@ -import wandb -from ..config import * -import lightning.pytorch as pl import os + +import lightning.pytorch as pl import matplotlib.pyplot as plt -import torchmetrics +import wandb class LocalLogger: @@ -17,9 +16,7 @@ def __init__(self, path: str = "local_testing_output", overwrite: bool = False): def test_plot(self, dataloader, data_type, name, plot: plt.Figure, idx=None): # save the wandb Image to a png - plot.savefig( - os.path.join(self.path, f"{dataloader}_{data_type}_{name}_{idx}.png") - ) + plot.savefig(os.path.join(self.path, f"{dataloader}_{data_type}_{name}_{idx}.png")) plt.close(plot) diff --git a/efold/core/metrics.py b/efold/core/metrics.py index 6f2f25b..1caa73f 100644 --- a/efold/core/metrics.py +++ b/efold/core/metrics.py @@ -1,9 +1,8 @@ -import torch -from ..config import UKN, POSSIBLE_METRICS -import torch -from .batch import Batch import numpy as np -from typing import TypedDict +import torch + +from efold.constants import config +from efold.core import batch # wrapper for metrics @@ -11,7 +10,7 @@ def mask_and_flatten(func): def wrapped(pred, true): if pred is None or true is None: return np.nan - mask = true != UKN + mask = true != config.pytorch.unknown_value if torch.sum(mask) == 0: return np.nan pred = pred[mask] @@ -84,9 +83,7 @@ def r2_score(pred, true): :return: R2 score """ - return ( - 1 - torch.sum((true - pred) ** 2) / torch.sum((true - torch.mean(true)) ** 2) - ).item() + return (1 - torch.sum((true - pred) ** 2) / torch.sum((true - torch.mean(true)) ** 2)).item() @mask_and_flatten @@ -100,9 +97,7 @@ def pearson_coefficient(pred, true): """ return torch.mean( - (pred - torch.mean(pred)) - * (true - torch.mean(true)) - / (torch.std(pred) * torch.std(true)) + (pred - torch.mean(pred)) * (true - torch.mean(true)) / (torch.std(pred) * torch.std(true)) ).item() @@ -135,18 +130,18 @@ def __init__(self, name, data_type=["dms", "shape", "structure"]): self.shape = dict(mae=[], pearson=[], r2=[]) self.structure = dict(f1=[]) - def update(self, batch: Batch): + def update(self, batch: batch.Batch): for dt in self.data_type: pred, true = batch.get_pairs(dt) - for metric in POSSIBLE_METRICS[dt]: + for metric in config.metrics.possible_metrics[dt]: self._add_metric(dt, metric, metric_factory[metric](pred, true)) return self def compute(self) -> dict: - out = {} + out: dict = {} for dt in self.data_type: out[dt] = {} - for metric in POSSIBLE_METRICS[dt]: + for metric in config.metrics.possible_metrics[dt]: out[dt][metric] = self._get_nanmean(dt, metric) return out diff --git a/efold/core/model.py b/efold/core/model.py index f854e95..2193e48 100644 --- a/efold/core/model.py +++ b/efold/core/model.py @@ -1,17 +1,13 @@ from typing import Any + import lightning.pytorch as pl -from lightning.pytorch.utilities.types import STEP_OUTPUT -import torch.nn as nn import torch -from ..config import device, UKN, TEST_SETS_NAMES +import torch.nn as nn import torch.nn.functional as F -from .batch import Batch -from torchmetrics import R2Score, PearsonCorrCoef, MeanAbsoluteError, F1Score -from .metrics import MetricsStack -from .datamodule import DataModule -import time +from lightning.pytorch.utilities.types import STEP_OUTPUT -from .postprocess import Postprocess +from efold.constants import config +from efold.core import batch, metrics, postprocess METRIC_ARGS = dict(dist_sync_on_step=True) @@ -45,16 +41,16 @@ def __init__(self, lr: float, optimizer_fn, weight_data: bool = False, **kwargs) self.automatic_optimization = True self.weight_data = weight_data - self.save_hyperparameters(ignore=['loss_fn']) - self.lossBCE = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([300])).to(device) + self.save_hyperparameters(ignore=["loss_fn"]) + self.lossBCE = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([300])).to(config.device) # Metrics self.metrics_stack = None self.tic = None - self.test_results = {'reference':[], 'sequence':[] ,'structure':[]} + self.test_results: dict[str, list[Any]] = {"reference": [], "sequence": [], "structure": []} - self.postprocesser = Postprocess() + self.postprocesser = postprocess.Postprocess() def configure_optimizers(self): optimizer = self.optimizer_fn( @@ -69,7 +65,7 @@ def configure_optimizers(self): scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.gamma) return [optimizer], [scheduler] - def _loss_signal(self, batch: Batch, data_type: str): + def _loss_signal(self, batch: batch.Batch, data_type: str): assert data_type in [ "dms", "shape", @@ -85,7 +81,7 @@ def _loss_signal(self, batch: Batch, data_type: str): ## vv MSE loss vv ## mask = torch.zeros_like(true) - mask[true != UKN] = 1 + mask[true != config.pytorch.unknown_value] = 1 loss = F.mse_loss(pred * mask, true * mask) non_zeros = (mask == 1).sum() / mask.numel() @@ -96,13 +92,13 @@ def _loss_signal(self, batch: Batch, data_type: str): assert not torch.isnan(loss), "Loss is NaN for {}".format(data_type) return loss - def _loss_structure(self, batch: Batch): + def _loss_structure(self, batch: batch.Batch): pred, true = batch.get_pairs("structure") loss = self.lossBCE(pred, true) assert not torch.isnan(loss), "Loss is NaN for structure" return loss - def loss_fn(self, batch: Batch): + def loss_fn(self, batch: batch.Batch): count = {k: v for k, v in batch.dt_count.items() if k in self.data_type_output} losses = {} if "dms" in count.keys(): @@ -127,23 +123,21 @@ def _clean_predictions(self, batch, predictions): predictions[data_type] = torch.clip(predictions[data_type], min=0, max=1) return predictions - def training_step(self, batch: Batch, batch_idx: int): + def training_step(self, batch: batch.Batch, batch_idx: int): predictions = self.forward(batch) batch.integrate_prediction(predictions) loss = self.loss_fn(batch)[0] - self.log(f"train/loss", loss, sync_dist=True) + self.log("train/loss", loss, sync_dist=True) return loss def on_validation_start(self): val_dl_names = self.trainer.datamodule.external_valid self.metrics_stack = [ - MetricsStack(name=name, data_type=self.data_type_output) + metrics.MetricsStack(name=name, data_type=self.data_type_output) for name in val_dl_names ] - def on_train_batch_end( - self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int - ) -> None: + def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None: del outputs del batch if batch_idx % 100 == 0: @@ -152,10 +146,12 @@ def on_train_batch_end( def on_train_end(self) -> None: torch.cuda.empty_cache() - def validation_step(self, batch: Batch, batch_idx: int, dataloader_idx=0): + def validation_step(self, batch: batch.Batch, batch_idx: int, dataloader_idx=0): predictions = self.forward(batch) - - predictions['structure'] = self.postprocesser.run(predictions['structure'], batch.get('sequence')) + + predictions["structure"] = self.postprocesser.run( + predictions["structure"], batch.get("sequence") + ) batch.integrate_prediction(predictions) # loss, losses = self.loss_fn(batch) @@ -173,8 +169,8 @@ def on_validation_epoch_end(self) -> None: # aggregate the stack and log it for metrics_dl in self.metrics_stack: metrics_pack = metrics_dl.compute() - for dt, metrics in metrics_pack.items(): - for name, metric in metrics.items(): + for dt, metrics_dict in metrics_pack.items(): + for name, metric in metrics_dict.items(): # to replace with a gather_all? self.log( f"valid/{metrics_dl.name}/{dt}/{name}", @@ -185,14 +181,20 @@ def on_validation_epoch_end(self) -> None: self.metrics_stack = None torch.cuda.empty_cache() - def test_step(self, batch: Batch, batch_idx: int, dataloader_idx=0): + def test_step(self, batch: batch.Batch, batch_idx: int, dataloader_idx=0): predictions = self.forward(batch) - predictions['structure'] = self.postprocesser.run(predictions['structure'], batch.get('sequence')) + predictions["structure"] = self.postprocesser.run( + predictions["structure"], batch.get("sequence") + ) - from ..config import int2seq - self.test_results['reference'] += batch.get('reference') - self.test_results['sequence'] += [''.join([int2seq[base] for base in seq]) for seq in batch.get('sequence').detach().tolist()] - self.test_results['structure'] += predictions['structure'].tolist() + from efold.constants import config + + self.test_results["reference"] += batch.get("reference") + self.test_results["sequence"] += [ + "".join([config.tokens.int2seq[base] for base in seq]) + for seq in batch.get("sequence").detach().tolist() + ] + self.test_results["structure"] += predictions["structure"].tolist() predictions = self._clean_predictions(batch, predictions) batch.integrate_prediction(predictions) @@ -201,12 +203,12 @@ def on_test_batch_end( self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0 ) -> None: # push the metric directly - metric_pack = MetricsStack( - name=TEST_SETS_NAMES[dataloader_idx], + metric_pack = metrics.MetricsStack( + name=config.test_sets.all_names[dataloader_idx], data_type=self.data_type_output, - ) - for dt, metrics in metric_pack.update(batch).compute().items(): - for name, metric in metrics.items(): + ) + for dt, metric_dict in metric_pack.update(batch).compute().items(): + for name, metric in metric_dict.items(): self.log( f"test/{metric_pack.name}/{dt}/{name}", float(metric), @@ -218,16 +220,16 @@ def on_test_batch_end( def on_test_epoch_end(self) -> None: torch.cuda.empty_cache() - + def on_test_end(self) -> None: - import pandas as pd + df = pd.DataFrame(self.test_results) - df.to_feather('test_results.feather') + df.to_feather("test_results.feather") torch.cuda.empty_cache() - def predict_step(self, batch: Batch, batch_idx: int): + def predict_step(self, batch: batch.Batch, batch_idx: int): predictions = self.forward(batch) predictions = self._clean_predictions(batch, predictions) batch.integrate_prediction(predictions) diff --git a/efold/core/path.py b/efold/core/path.py index df881ca..4d79cdd 100644 --- a/efold/core/path.py +++ b/efold/core/path.py @@ -1,8 +1,8 @@ -from os.path import join import os -from rouskinhf.env import Env -import numpy as np import pickle +from os.path import join + +import numpy as np from rouskinhf.path import Path as RouskinPath @@ -49,36 +49,30 @@ def get_reference(self) -> str: """Returns the path to the references.txt file.""" return join(self.get_main_folder(), "references.npy") - def load_reference(self): - """Returns the list of references.""" + def load_reference(self) -> np.ndarray: return np.load(self.get_reference(), allow_pickle=True) - def dump_reference(self, references): - """Dumps the list of references.""" + def dump_reference(self, references: np.ndarray) -> None: np.save(self.get_reference(), references) def get_sequence(self) -> str: """Returns the path to the sequences.txt file.""" return join(self.get_main_folder(), "sequences.npy") - def load_sequence(self): - """Returns the list of sequences.""" + def load_sequence(self) -> np.ndarray: return np.load(self.get_sequence(), allow_pickle=True) - def dump_sequence(self, sequences): - """Dumps the list of sequences.""" + def dump_sequence(self, sequences: np.ndarray) -> None: np.save(self.get_sequence(), sequences) def get_length(self) -> str: """Returns the path to the lengths.txt file.""" return join(self.get_main_folder(), "lengths.npy") - def load_length(self): - """Returns the list of lengths.""" + def load_length(self) -> np.ndarray: return np.load(self.get_length(), allow_pickle=True) - def dump_length(self, lengths): - """Dumps the list of lengths.""" + def dump_length(self, lengths: np.ndarray) -> None: np.save(self.get_length(), lengths) def get_dms(self) -> str: @@ -86,14 +80,12 @@ def get_dms(self) -> str: return join(self.get_main_folder(), "dms.pkl") def load_dms(self): - """Returns the list of dms.""" if not os.path.exists(self.get_dms()): return None return pickle.load(open(self.get_dms(), "rb")) @dont_dump_none - def dump_dms(self, val) -> str: - """Dumps the list of dms.""" + def dump_dms(self, val) -> None: pickle.dump(val, open(self.get_dms(), "wb")) def get_shape(self) -> str: @@ -101,14 +93,12 @@ def get_shape(self) -> str: return join(self.get_main_folder(), "shapes.pkl") def load_shape(self): - """Returns the list of shapes.""" if not os.path.exists(self.get_shape()): return None return pickle.load(open(self.get_shape(), "rb")) @dont_dump_none - def dump_shape(self, val): - """Dumps the list of shapes.""" + def dump_shape(self, val) -> None: pickle.dump(val, open(self.get_shape(), "wb")) def get_structure(self) -> str: @@ -116,12 +106,10 @@ def get_structure(self) -> str: return join(self.get_main_folder(), "structures.pkl") def load_structure(self): - """Returns the list of structures.""" if not os.path.exists(self.get_structure()): return None return pickle.load(open(self.get_structure(), "rb")) @dont_dump_none - def dump_structure(self, val): - """Dumps the list of structures.""" + def dump_structure(self, val) -> None: pickle.dump(val, open(self.get_structure(), "wb")) diff --git a/efold/core/postprocess.py b/efold/core/postprocess.py index 1448544..0dd6c91 100644 --- a/efold/core/postprocess.py +++ b/efold/core/postprocess.py @@ -1,13 +1,13 @@ -import torch -import math import numpy as np +import torch import torch.nn.functional as F from scipy.optimize import linear_sum_assignment -from ..config import seq2int -class Constraints: +from efold.constants import config + - """ Compute a mask of constraints to remove sharp loops and non-canonical pairs from the input matrix +class Constraints: + """Compute a mask of constraints to remove sharp loops and non-canonical pairs from the input matrix Args: - input_matrix (torch.Tensor): n x n matrix of base pair probabilities @@ -17,58 +17,58 @@ class Constraints: Example: >>> inpt = torch.tensor([[0.1, 0.6, 0.8],[0.6, 0.1, 0.9],[0.8, 0.9, 0.1]]) - >>> sequence = torch.tensor([seq2int[a] for a in "GCU"]) + >>> sequence = torch.tensor([config.tokens.seq2int[a] for a in "GCU"]) >>> out = Constraints().apply_constraints(inpt, sequence=sequence, min_hairpin_length=0, canonical_only=True) >>> assert (out == torch.tensor([[0.0, 0.6, 0.8],[0.6, 0.0, 0.0],[0.8, 0.0, 0.0]])).all(), "The output is not as expected: {}".format(out) """ - def apply_constraints(self, input_matrix, min_hairpin_length=3, canonical_only=True, sequence=None): - + def apply_constraints( + self, input_matrix, min_hairpin_length=3, canonical_only=True, sequence=None + ): # mask elements of the diagonals and sub-diagonals mask = self.mask_sharpLoops(input_matrix, min_hairpin_length) # Mask elements of the matrix that are not A-U, G-C, or G-U pairs using the sequence - if canonical_only: mask *= self.mask_nonCanonical(sequence) - + if canonical_only: + mask *= self.mask_nonCanonical(sequence) + return input_matrix * mask def mask_sharpLoops(self, input_matrix, min_hairpin_length): - # mask elements of the diagonals and sub-diagonals - mask = np.tri(input_matrix.shape[0], k=-min_hairpin_length-1).astype(int) + mask = np.tri(input_matrix.shape[0], k=-min_hairpin_length - 1).astype(int) return torch.tensor(mask + mask.T, device=input_matrix.device) def mask_nonCanonical(self, sequence): - # Embed sequence - if type(sequence) == str: sequence = torch.tensor([seq2int[a] for a in sequence]) - + if isinstance(sequence, str): + sequence = torch.tensor([config.tokens.seq2int[a] for a in sequence]) + # make the pairing matrix sequence = sequence.reshape(-1, 1) pair_of_bases = sequence + sequence.T # find the allowable pairs allowable_pair = set() - for pair in ["GU", "GC", "AU"]: allowable_pair.add(seq2int[pair[0]] + seq2int[pair[1]]) + for pair in ["GU", "GC", "AU"]: + allowable_pair.add(config.tokens.seq2int[pair[0]] + config.tokens.seq2int[pair[1]]) allowable_pair = torch.tensor(list(allowable_pair), device=pair_of_bases.device) return torch.isin(pair_of_bases, allowable_pair).int() - class HungarianAlgorithm: - def run(self, bppm, threshold=0.5): """Runs the Hungarian algorithm on the input bppm matrix - + Args: - bppm (torch.Tensor): n x n matrix of base pair probabilities - + Example: >>> inpt = np.diag(np.ones(10))[::-1] >>> inpt += np.random.normal(0, 0.2, inpt.shape) - >>> inpt = (inpt + inpt.T)/2 + >>> inpt = (inpt + inpt.T)/2 >>> inpt = torch.tensor(inpt) >>> out = HungarianAlgorithm().run(inpt) >>> assert (out == [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],\ @@ -84,25 +84,25 @@ def run(self, bppm, threshold=0.5): >>> out = HungarianAlgorithm().run(torch.tensor([[0., 0.6, 0.8],[0.6, 0., 0.9],[0.8, 0.9, 0.]]) ) >>> assert (out == [[0., 1., 1.],[1., 0., 1.],[1., 1., 0.]]).all(), "The output is not as expected: {}".format(out) """ - + assert len(bppm.shape) == 2, "The input bppm matrix should be n x n" assert bppm.shape[0] == bppm.shape[1], "The input bppm matrix should be n x n" assert self.is_symmetric(bppm), f"The input bppm matrix should be symmetric, {bppm}" - + # just work with numpy (needed for the optimization step) - if type(bppm)==torch.Tensor: + if isinstance(bppm, torch.Tensor): device = bppm.device bppm = bppm.cpu().numpy() - - # run hungarian algorithm - bp_matrix = np.zeros(bppm.shape) - + + # run hungarian algorithm + bp_matrix = np.zeros(bppm.shape) + # run hungarian algorithm only on rows and columns that have at least one value greater than threshold compression_idx = self._pairable_bases(bppm, threshold) compressed_bppm = bppm[compression_idx][:, compression_idx] - row_ind, col_ind = self._hungarian_algorithm(compressed_bppm) - + row_ind, col_ind = self._hungarian_algorithm(compressed_bppm) + # convert the result to the original sized matrix for compressed_row, compressed_col in zip(row_ind, col_ind): bppm_row, bppm_col = compression_idx[compressed_row], compression_idx[compressed_col] @@ -111,31 +111,29 @@ def run(self, bppm, threshold=0.5): bp_matrix[bppm_col, bppm_row] = 1 return torch.tensor(bp_matrix, device=device) - + def _hungarian_algorithm(self, cost_matrix): """Returns the row and column indices of the optimal assignment using the Hungarian algorithm""" row_ind, col_ind = linear_sum_assignment(cost_matrix, maximize=True) return row_ind, col_ind - + def _pairable_bases(self, bppm, threshold): """Returns the indices of rows that have at least one value greater than threshold - + Example: >>> assert (HungarianAlgorithm()._pairable_bases(np.array([[0.0, 0.0, 0.0], [0.0, 0.6, 0.8], [0.0, 0.8, 0.9]]), 0.01) == np.array([1, 2])).all(), "The output is not as expected" """ return np.where((bppm > threshold).any(axis=0))[0] - + def is_symmetric(self, bppm): return (bppm == bppm.transpose(1, 0)).all() - class UFold_processing: - def run(self, bppm): return self.postprocess(u=bppm) - - def postprocess(self, u, lr_min=0.01, lr_max=0.1, num_itr=100, rho=1.6, with_l1=True,s=1.5): + + def postprocess(self, u, lr_min=0.01, lr_max=0.1, num_itr=100, rho=1.6, with_l1=True, s=1.5): """ :param u: utility matrix, u is assumed to be symmetric :param lr_min: learning rate for minimization step @@ -145,10 +143,11 @@ def postprocess(self, u, lr_min=0.01, lr_max=0.1, num_itr=100, rho=1.6, with_l1= :param with_l1: :return: """ + def soft_sign(x): k = 1 - return 1.0/(1.0+torch.exp(-2*k*x)) - + return 1.0 / (1.0 + torch.exp(-2 * k * x)) + def contact_a(a_hat, m): a = a_hat * a_hat a = (a + torch.transpose(a, -1, -2)) / 2 @@ -167,8 +166,9 @@ def contact_a(a_hat, m): # gradient descent for t in range(num_itr): - - grad_a = (lmbd * soft_sign(torch.sum(contact_a(a_hat, m), dim=-1) - 1)).unsqueeze_(-1).expand(u.shape) - u / 2 + grad_a = (lmbd * soft_sign(torch.sum(contact_a(a_hat, m), dim=-1) - 1)).unsqueeze_( + -1 + ).expand(u.shape) - u / 2 grad = a_hat * m * (grad_a + torch.transpose(grad_a, -1, -2)) a_hat -= lr_min * grad lr_min = lr_min * 0.99 @@ -186,40 +186,40 @@ def contact_a(a_hat, m): # grad_a = (lmbd * soft_sign(torch.sum(contact_a(a_hat, m), dim=-1) - 1)).unsqueeze_(-1).expand(u.shape) - u / 2 # grad = a_hat * m * (grad_a + torch.transpose(grad_a, -1, -2)) # n2 = torch.norm(grad) - # print([t, 'norms', n1, n2, aug_lagrangian(u, m, a_hat, lmbd), torch.sum(contact_a(a_hat, u))]) + # print([t, 'norms', n1, n2, aug_lagrangian(u, m, a_hat, lmbd), + # torch.sum(contact_a(a_hat, u))]) a = a_hat * a_hat a = (a + torch.transpose(a, -1, -2)) / 2 a = a * m return a - -class Postprocess: +class Postprocess: def __init__(self, threshold=0.5, canonical_only=True, min_hairpin_length=3): self.threshold = threshold self.canonical_only = canonical_only self.min_hairpin_length = min_hairpin_length def run(self, bppms, sequence): - if len(bppms.shape) == 2: bppms = bppms.unsqueeze(0) pairing_matrices = [] for bppm in bppms: + pairing_matrix = Constraints().apply_constraints( + bppm, + sequence=sequence, + min_hairpin_length=self.min_hairpin_length, + canonical_only=self.canonical_only, + ) - pairing_matrix = Constraints().apply_constraints(bppm, sequence=sequence, - min_hairpin_length=self.min_hairpin_length, - canonical_only=self.canonical_only) - pairing_matrix_UFold = UFold_processing().run(pairing_matrix) - if not pairing_matrix_UFold.isnan().any(): pairing_matrix = pairing_matrix_UFold - + if not pairing_matrix_UFold.isnan().any(): + pairing_matrix = pairing_matrix_UFold pairing_matrix = HungarianAlgorithm().run(pairing_matrix, threshold=self.threshold) pairing_matrices.append(pairing_matrix) return (torch.stack(pairing_matrices) > self.threshold).type(torch.int) - diff --git a/efold/core/sampler.py b/efold/core/sampler.py index d76c1ac..49cd5f8 100644 --- a/efold/core/sampler.py +++ b/efold/core/sampler.py @@ -1,17 +1,18 @@ -from torch.utils.data import Sampler, Subset -import numpy as np -# from random import shuffle -from torch.utils.data import Dataset -from typing import Union, Optional, TypeVar, Iterator -import torch.distributed as dist import math -import torch import os +from typing import Iterator, Optional, TypeVar, Union -T_co = TypeVar('T_co', covariant=True) +import numpy as np +import torch +import torch.distributed as dist + +# from random import shuffle +from torch.utils.data import Dataset, Sampler, Subset +T_co = TypeVar("T_co", covariant=True) -class DDPSampler(Sampler): + +class DDPSampler(Sampler): r"""Sampler that restricts data loading to a subset of the dataset. It is especially useful in conjunction with @@ -58,12 +59,17 @@ class DDPSampler(Sampler): ... if is_distributed: ... sampler.set_epoch(epoch) ... train(loader) - """ - def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, - rank: Optional[int] = None, shuffle: bool = True, - seed: int = os.environ.get('PL_GLOBAL_SEED', 0), - drop_last: bool = False) -> None: - + """ + + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = int(os.environ.get("PL_GLOBAL_SEED", 0)), + drop_last: bool = False, + ) -> None: if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") @@ -74,7 +80,9 @@ def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, rank = dist.get_rank() if rank >= num_replicas or rank < 0: raise ValueError( - f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}] because num_replicas={num_replicas}") + f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]", + f"because num_replicas={num_replicas}", + ) self.dataset = dataset self.num_replicas = num_replicas self.rank = rank @@ -94,12 +102,14 @@ def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle self.seed = seed - + if isinstance(dataset, Subset): - self.length = dataset.dataset.length[slice(dataset.indices.start, dataset.indices.stop, dataset.indices.step)] + self.length = dataset.dataset.length[ + slice(dataset.indices.start, dataset.indices.stop, dataset.indices.step) + ] elif isinstance(dataset, Dataset): - self.length = dataset.length - + self.length = dataset.length + def __iter__(self) -> Iterator[T_co]: # deterministically shuffle based on epoch and seed if self.shuffle: @@ -118,24 +128,24 @@ def __iter__(self) -> Iterator[T_co]: indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] else: # remove tail of data to make it evenly divisible. - indices = indices[:self.total_size] + indices = indices[: self.total_size] assert len(indices) == self.total_size # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples # sort by length - lengths = [self.length[i] for i in indices] + lengths = [self.length[i] for i in indices] length_order = np.argsort(lengths) indices = [indices[i] for i in length_order] - + # shuffle them again to avoid having the samples sorted by length g = torch.Generator() g.manual_seed(self.seed + self.epoch + 42) deterministic_order = torch.randperm(len(indices), generator=g).tolist() indices = [indices[i] for i in deterministic_order] - + return iter(indices) def __len__(self) -> int: @@ -156,14 +166,13 @@ def set_epoch(self, epoch: int) -> None: def sampler_factory( dataset: Union[Dataset, Subset], strategy: str, - seed:int = os.environ.get('PL_GLOBAL_SEED', 0), + seed: int = int(os.environ.get("PL_GLOBAL_SEED", 0)), num_replicas: Optional[int] = None, rank: Optional[int] = None, ): - if strategy in ['random', 'sorted']: + if strategy in ["random", "sorted"]: return None - elif strategy == 'ddp': + elif strategy == "ddp": return DDPSampler(dataset, num_replicas=num_replicas, rank=rank, shuffle=True, seed=seed) else: raise ValueError(f"Invalid strategy value: {strategy}") - \ No newline at end of file diff --git a/efold/core/util.py b/efold/core/util.py index 0bd09d2..6da8a48 100644 --- a/efold/core/util.py +++ b/efold/core/util.py @@ -1,25 +1,25 @@ -from ..config import UKN -from .embeddings import base_pairs_to_pairing_matrix -import torch.nn.functional as F -from torch import tensor import torch +import torch.nn.functional as F + +from efold.constants import config +from efold.core import embeddings -def _pad(arr, L, data_type): +def _pad(arr: torch.Tensor, L: int, data_type: str) -> torch.Tensor: padding_values = { "sequence": 0, - "dms": UKN, - "shape": UKN, + "dms": config.pytorch.unknown_value, + "shape": config.pytorch.unknown_value, } if data_type == "structure": - return base_pairs_to_pairing_matrix(arr, L) + return embeddings.base_pairs_to_pairing_matrix(arr, L) else: if isinstance(arr, list): arr = torch.tensor(arr) return F.pad(arr, (0, L - arr.shape[1]), value=padding_values[data_type]) -def split_data_type(data_type): +def split_data_type(data_type: str) -> tuple[str, str]: if "_" not in data_type: data_part = "true" else: diff --git a/efold/core/visualisation.py b/efold/core/visualisation.py index 01c78ef..e257c16 100644 --- a/efold/core/visualisation.py +++ b/efold/core/visualisation.py @@ -1,10 +1,10 @@ -from matplotlib import pyplot as plt import numpy as np -import wandb -from .metrics import r2_score, mae_score, pearson_coefficient -from ..config import UKN +from matplotlib import pyplot as plt from rouskinhf import int2seq +from efold.constants import config +from efold.core import metrics + matplotlib_colors = [ "red", "blue", @@ -38,9 +38,9 @@ def plot_signal( fig, ax = plt.subplots() # Compute metrics while you still have tensors - r2 = r2_score(pred=pred, true=true) - r = pearson_coefficient(pred=pred, true=true) - mae = mae_score(pred=pred, true=true) + r2 = metrics.r2_score(pred=pred, true=true) + r = metrics.pearson_coefficient(pred=pred, true=true) + mae = metrics.mae_score(pred=pred, true=true) # Base position with no coverage or G/U base are removed def chop_array(x): @@ -50,7 +50,7 @@ def known_bases_to_list(x, mask): return x[mask].cpu().numpy() pred, true, sequence = chop_array(pred), chop_array(true), chop_array(sequence) - mask = true != UKN + mask = true != config.pytorch.unknown_value true, pred, sequence = ( known_bases_to_list(true, mask), known_bases_to_list(pred, mask), @@ -129,11 +129,7 @@ def plot_structure(pred, true, **kwargs): plot_factory = { - ("dms", "scatter"): lambda *args, **kwargs: plot_signal( - *args, **kwargs, data_type="DMS" - ), - ("shape", "scatter"): lambda *args, **kwargs: plot_signal( - *args, **kwargs, data_type="SHAPE" - ), + ("dms", "scatter"): lambda *args, **kwargs: plot_signal(*args, **kwargs, data_type="DMS"), + ("shape", "scatter"): lambda *args, **kwargs: plot_signal(*args, **kwargs, data_type="SHAPE"), ("structure", "heatmap"): plot_structure, } diff --git a/efold/models/__init__.py b/efold/models/__init__.py index 594023d..e69de29 100644 --- a/efold/models/__init__.py +++ b/efold/models/__init__.py @@ -1 +0,0 @@ -from .factory import create_model diff --git a/efold/models/cnn.py b/efold/models/cnn.py index d664df3..e53a20c 100644 --- a/efold/models/cnn.py +++ b/efold/models/cnn.py @@ -1,19 +1,13 @@ -import torch -from torch import nn, Tensor -from torch.nn import TransformerEncoderLayer import numpy as np -import os -import sys -from ..core.model import Model -from ..core.batch import Batch -from einops import rearrange +import torch import torch.nn.functional as F +from einops import rearrange +from torch import Tensor, nn -dir_name = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(os.path.join(dir_name, "..")) +from efold.core import batch, model -class CNN(Model): +class CNN(model.Model): def __init__( self, ntoken: int, @@ -89,7 +83,7 @@ def __init__( # nn.Linear(d_model, 1), # ) - def forward(self, batch: Batch) -> Tensor: + def forward(self, batch: batch.Batch) -> dict[str, Tensor]: """ Args: src: Tensor, shape [seq_len, batch_size] @@ -105,9 +99,7 @@ def forward(self, batch: Batch) -> Tensor: # Outer concatenation matrix = x.unsqueeze(1).repeat(1, x.shape[1], 1, 1) # (N, L, L, d_cnn/2) - matrix = torch.cat( - (matrix, matrix.permute(0, 2, 1, 3)), dim=-1 - ) # (N, L, L, d_cnn) + matrix = torch.cat((matrix, matrix.permute(0, 2, 1, 3)), dim=-1) # (N, L, L, d_cnn) # Resnet layers matrix = self.res_layers(matrix.permute(0, 3, 1, 2)) # (N, d_cnn//8, L, L) @@ -225,9 +217,7 @@ def __init__(self, n_blocks, dim_in, dim_out, kernel_size, dropout=0.0): self.res_blocks = nn.Sequential(*self.res_layers) # Adapter to change depth - self.conv_output = nn.Conv2d( - dim_in, dim_out, kernel_size=7, padding=3, bias=True - ) + self.conv_output = nn.Conv2d(dim_in, dim_out, kernel_size=7, padding=3, bias=True) def forward(self, x: Tensor) -> Tensor: x = self.res_blocks(x) @@ -252,14 +242,10 @@ def __init__( self.bn1 = nn.BatchNorm2d(inplanes) self.relu1 = nn.ReLU(inplace=True) - self.conv1 = conv3x3( - inplanes, planes, dilation=dilation1, kernel_size=kernel_size - ) + self.conv1 = conv3x3(inplanes, planes, dilation=dilation1, kernel_size=kernel_size) self.dropout = nn.Dropout(p=dropout) self.relu2 = nn.ReLU(inplace=True) - self.conv2 = conv3x3( - planes, planes, dilation=dilation2, kernel_size=kernel_size - ) + self.conv2 = conv3x3(planes, planes, dilation=dilation2, kernel_size=kernel_size) def forward(self, x: Tensor) -> Tensor: identity = x @@ -276,9 +262,7 @@ def forward(self, x: Tensor) -> Tensor: return out -def conv3x3( - in_planes: int, out_planes: int, dilation: int = 1, kernel_size=3 -) -> nn.Conv2d: +def conv3x3(in_planes: int, out_planes: int, dilation: int = 1, kernel_size=3) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d( in_planes, diff --git a/efold/models/efold.py b/efold/models/efold.py index cdda582..d11400b 100644 --- a/efold/models/efold.py +++ b/efold/models/efold.py @@ -1,24 +1,17 @@ -import numpy as np -import torch -from torch import nn, Tensor -import os -import sys +from collections import defaultdict from contextlib import ExitStack +from typing import List, Union -import typing as T -from einops import rearrange -import torch.nn.functional as F import numpy as np -from ..core.batch import Batch -from ..core.model import Model - -from collections import defaultdict +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor, nn -dir_name = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(os.path.join(dir_name, "..")) +from efold.core import batch, model -class eFold(Model): +class eFold(model.Model): def __init__( self, ntoken: int, @@ -34,16 +27,14 @@ def __init__( optimizer_fn=torch.optim.Adam, **kwargs, ): - self.save_hyperparameters(ignore=['loss_fn']) - super().__init__( - lr=lr, loss_fn=loss_fn, optimizer_fn=optimizer_fn, **kwargs - ) + self.save_hyperparameters(ignore=["loss_fn"]) + super().__init__(lr=lr, loss_fn=loss_fn, optimizer_fn=optimizer_fn, **kwargs) self.model_type = "eFold" self.data_type_output = ["structure"] self.lr = lr self.gamma = gamma - self.train_losses = [] + self.train_losses: list[float] = [] self.loss = nn.MSELoss() # Encoder layers @@ -86,17 +77,15 @@ def __init__( kernel_size=3, dropout=dropout, ), - ResLayer( - dim_in=d_cnn // 2, dim_out=1, n_blocks=4, kernel_size=3, dropout=dropout - ), + ResLayer(dim_in=d_cnn // 2, dim_out=1, n_blocks=4, kernel_size=3, dropout=dropout), ) - def forward(self, batch: Batch) -> Tensor: + def forward(self, batch: batch.Batch) -> dict[str, Tensor]: # Encoding of RNA sequence src = batch.get("sequence") - + s = self.encoder(src) # (N, L, d_model) - z = self.encoder_adapter(self.seq2map(src)).permute(0, 2, 3, 1) # (N, L, L, d_model) + z = self.encoder_adapter(self.seq2map(src)).permute(0, 2, 3, 1) # (N, L, L, d_model) # z = self.activ(self.encoder_adapter(s)) # (N, L, c_z / 2) # # Outer concatenation @@ -106,63 +95,75 @@ def forward(self, batch: Batch) -> Tensor: s, z = self.eFold(s, z) structure = self.structure_adapter(z) # (N, L, L, d_cnn) - structure = self.output_structure(structure.permute(0, 3, 1, 2)).squeeze( - 1 - ) # (N, L, L) + structure = self.output_structure(structure.permute(0, 3, 1, 2)).squeeze(1) # (N, L, L) return { # "dms": self.output_net_DMS(s).squeeze(axis=2), # "shape": self.output_net_SHAPE(s).squeeze(axis=2), - "structure": (structure + structure.permute(0, 2, 1)) - / 2, + "structure": (structure + structure.permute(0, 2, 1)) / 2, } - - def seq2map(self, seq_int): + def seq2map(self, seq_int): def int2seq(seq): # return ''.join(['XAUCG'[d] for d in seq]) - return ''.join(['XACGU'[d] for d in seq]) + return "".join(["XACGU"[d] for d in seq]) # take integer encoded sequence and return last channel of embedding (pairing energy) def creatmat(data, device=None): - with torch.no_grad(): data = int2seq(data) - paired = defaultdict(float, {'AU':2., 'UA':2., 'GC':3., 'CG':3., 'UG':0.8, 'GU':0.8}) + paired = defaultdict( + float, {"AU": 2.0, "UA": 2.0, "GC": 3.0, "CG": 3.0, "UG": 0.8, "GU": 0.8} + ) - mat = torch.tensor([[paired[x+y] for y in data] for x in data]).to(device) + mat = torch.tensor([[paired[x + y] for y in data] for x in data]).to(device) n = len(data) - i, j = torch.meshgrid(torch.arange(n).to(device), torch.arange(n).to(device), indexing='ij') + i, j = torch.meshgrid( + torch.arange(n).to(device), torch.arange(n).to(device), indexing="ij" + ) t = torch.arange(30).to(device) - m1 = torch.where((i[:, :, None] - t >= 0) & (j[:, :, None] + t < n), mat[torch.clamp(i[:,:,None]-t, 0, n-1), torch.clamp(j[:,:,None]+t, 0, n-1)], 0) - m1 *= torch.exp(-0.5*t*t) + m1 = torch.where( + (i[:, :, None] - t >= 0) & (j[:, :, None] + t < n), + mat[ + torch.clamp(i[:, :, None] - t, 0, n - 1), + torch.clamp(j[:, :, None] + t, 0, n - 1), + ], + 0, + ) + m1 *= torch.exp(-0.5 * t * t) m1_0pad = torch.nn.functional.pad(m1, (0, 1)) - first0 = torch.argmax((m1_0pad==0).to(int), dim=2) - to0indices = t[None,None,:]>first0[:,:,None] + first0 = torch.argmax((m1_0pad == 0).to(int), dim=2) + to0indices = t[None, None, :] > first0[:, :, None] m1[to0indices] = 0 m1 = m1.sum(dim=2) t = torch.arange(1, 30).to(device) - m2 = torch.where((i[:, :, None] + t < n) & (j[:, :, None] - t >= 0), mat[torch.clamp(i[:,:,None]+t, 0, n-1), torch.clamp(j[:,:,None]-t, 0, n-1)], 0) - m2 *= torch.exp(-0.5*t*t) + m2 = torch.where( + (i[:, :, None] + t < n) & (j[:, :, None] - t >= 0), + mat[ + torch.clamp(i[:, :, None] + t, 0, n - 1), + torch.clamp(j[:, :, None] - t, 0, n - 1), + ], + 0, + ) + m2 *= torch.exp(-0.5 * t * t) m2_0pad = torch.nn.functional.pad(m2, (0, 1)) - first0 = torch.argmax((m2_0pad==0).to(int), dim=2) - to0indices = torch.arange(29).to(device)[None,None,:]>first0[:,:,None] + first0 = torch.argmax((m2_0pad == 0).to(int), dim=2) + to0indices = torch.arange(29).to(device)[None, None, :] > first0[:, :, None] m2[to0indices] = 0 m2 = m2.sum(dim=2) - m2[m1==0] = 0 + m2[m1 == 0] = 0 - return (m1+m2).to(self.device) + return (m1 + m2).to(self.device) # Assemble all data full_map = [] one_hot_embed = torch.zeros((5, 4), device=self.device) one_hot_embed[1:] = torch.eye(4) for seq in seq_int: - seq_hot = one_hot_embed[seq].type(torch.long) pair_map = torch.kron(seq_hot, seq_hot).reshape(len(seq), len(seq), 16) @@ -170,7 +171,6 @@ def creatmat(data, device=None): full_map.append(torch.cat((pair_map, energy_map.unsqueeze(-1)), dim=-1)) - return torch.stack(full_map).permute(0, 3, 1, 2).contiguous() @@ -213,9 +213,7 @@ def __init__( self.pos = PositionalEncoding(self.c_s, dropout) self.ln = nn.LayerNorm(self.c_s, eps=1e-12, elementwise_affine=True) - self.resNet = ResLayer( - dim_in=c_z, dim_out=c_z, n_blocks=2, kernel_size=3, dropout=dropout - ) + self.resNet = ResLayer(dim_in=c_z, dim_out=c_z, n_blocks=2, kernel_size=3, dropout=dropout) # self.tri_mul_out = TriangleMultiplicationOutgoing( # c_z, @@ -330,9 +328,7 @@ def forward(self, sequence_state, pairwise_state): # Update pairwise state pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state) - pairwise_state = self.resNet(pairwise_state.permute(0, 3, 1, 2)).permute( - 0, 2, 3, 1 - ) + pairwise_state = self.resNet(pairwise_state.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) # # Axial attention # pairwise_state = pairwise_state + self.row_drop( @@ -477,7 +473,7 @@ class Dropout(nn.Module): along a particular dimension. """ - def __init__(self, r: float, batch_dim: T.Union[int, T.List[int]]): + def __init__(self, r: float, batch_dim: Union[int, List[int]]): super(Dropout, self).__init__() self.r = r @@ -621,7 +617,6 @@ def __init__(self, n_blocks, dim_in, dim_out, kernel_size, dropout=0.0): # Basic Residula block self.res_layers = [] for i in range(n_blocks): - # dilation = pow(2, (i % 3)) self.res_layers.append( ResBlock( inplanes=dim_in, @@ -635,9 +630,7 @@ def __init__(self, n_blocks, dim_in, dim_out, kernel_size, dropout=0.0): self.res_blocks = nn.Sequential(*self.res_layers) # Adapter to change depth - self.conv_output = nn.Conv2d( - dim_in, dim_out, kernel_size=7, padding=3, bias=True - ) + self.conv_output = nn.Conv2d(dim_in, dim_out, kernel_size=7, padding=3, bias=True) def forward(self, x: Tensor) -> Tensor: x = self.res_blocks(x) @@ -662,14 +655,10 @@ def __init__( self.bn1 = nn.BatchNorm2d(inplanes) self.relu1 = nn.ReLU(inplace=True) - self.conv1 = conv3x3( - inplanes, planes, dilation=dilation1, kernel_size=kernel_size - ) + self.conv1 = conv3x3(inplanes, planes, dilation=dilation1, kernel_size=kernel_size) self.dropout = nn.Dropout(p=dropout) self.relu2 = nn.ReLU(inplace=True) - self.conv2 = conv3x3( - planes, planes, dilation=dilation2, kernel_size=kernel_size - ) + self.conv2 = conv3x3(planes, planes, dilation=dilation2, kernel_size=kernel_size) def forward(self, x: Tensor) -> Tensor: identity = x @@ -686,9 +675,7 @@ def forward(self, x: Tensor) -> Tensor: return out -def conv3x3( - in_planes: int, out_planes: int, dilation: int = 1, kernel_size=3 -) -> nn.Conv2d: +def conv3x3(in_planes: int, out_planes: int, dilation: int = 1, kernel_size=3) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d( in_planes, @@ -747,16 +734,13 @@ def __init__( self.dropout = nn.Dropout(dropout) self._dropout_rate = dropout - input_max = (self.num_heads * self.head_size) ** -0.5 ### DOUBLE CHECK THIS CODE: self.query = nn.Linear(num_heads * head_size, num_heads * head_size, bias=False) self.key = nn.Linear(num_heads * head_size, num_heads * head_size, bias=False) self.value = nn.Linear(num_heads * head_size, num_heads * head_size, bias=False) ########################### - self.projection_kernel = nn.Parameter( - torch.rand(num_heads, head_size, output_size) * 2 - 1 - ) + self.projection_kernel = nn.Parameter(torch.rand(num_heads, head_size, output_size) * 2 - 1) if use_projection_bias: self.projection_bias = nn.Parameter(torch.rand(output_size) * 2 - 1) @@ -786,10 +770,9 @@ def call_qkv(self, query, key, value, training=False): return query, key, value - def call_attention( - self, query, key, value, logits, bias=None, training=False, mask=None - ): - # Mask = attention mask with shape [B, Tquery, Tkey] with 1 for positions we want to attend, 0 for masked + def call_attention(self, query, key, value, logits, bias=None, training=False, mask=None): + # Mask = attention mask with shape [B, Tquery, Tkey] with 1 for positions we want to attend, + # 0 for masked if mask is not None: if len(mask.size()) < 2: raise ValueError("'mask' must have at least 2 dimensions") @@ -822,15 +805,11 @@ def call_attention( attn_coef_dropout = self.dropout(attn_coef) # Attention * value - multihead_output = torch.einsum( - "...HNM,...MHI->...NHI", attn_coef_dropout, value - ) + multihead_output = torch.einsum("...HNM,...MHI->...NHI", attn_coef_dropout, value) # Run the outputs through another linear projection layer. Recombining heads # is automatically done. - output = torch.einsum( - "...NHI,HIO->...NO", multihead_output, self.projection_kernel - ) + output = torch.einsum("...NHI,HIO->...NO", multihead_output, self.projection_kernel) if self.projection_bias is not None: output += self.projection_bias @@ -865,7 +844,6 @@ def __init__(self, kernel_sizes=None, strides=None, **kwargs): super(RelPositionMultiHeadAttention, self).__init__(**kwargs) num_pos_features = self.num_heads * self.head_size - input_max = (self.num_heads * self.head_size) ** -0.5 self.pos_kernel = nn.Parameter( torch.rand(self.num_heads, num_pos_features, self.head_size) * 2 - 1 ) @@ -943,10 +921,6 @@ def __init__( self.scale = nn.Parameter(torch.ones(input_dim)) self.bias = nn.Parameter(torch.zeros(input_dim)) - pw1_max = input_dim**-0.5 - dw_max = kernel_size**-0.5 - pw2_max = input_dim**-0.5 - self.pw_conv_1 = nn.Conv1d( in_channels=input_dim, out_channels=conv_expansion_rate * input_dim, diff --git a/efold/models/factory.py b/efold/models/factory.py index ed60e99..cc62c7e 100644 --- a/efold/models/factory.py +++ b/efold/models/factory.py @@ -1,19 +1,15 @@ -from .transformer import Transformer -from .efold import eFold -from .cnn import CNN -from .ribonanza import Ribonanza -from .unet import U_Net +from efold.models import cnn, efold, ribonanza, transformer, unet def create_model(model: str, *args, **kwargs): if model == "transformer": - return Transformer(*args, **kwargs) + return transformer.Transformer(*args, **kwargs) if model == "efold": - return eFold(*args, **kwargs) + return efold.eFold(*args, **kwargs) if model == "cnn": - return CNN(*args, **kwargs) + return cnn.CNN(*args, **kwargs) if model == "unet": - return U_Net(*args, **kwargs) + return unet.U_Net(*args, **kwargs) if model == "ribonanza": - return Ribonanza(*args, **kwargs) + return ribonanza.Ribonanza(*args, **kwargs) raise ValueError(f"Unknown model: {model}") diff --git a/efold/models/ribonanza.py b/efold/models/ribonanza.py index 5909dec..057a3b8 100644 --- a/efold/models/ribonanza.py +++ b/efold/models/ribonanza.py @@ -1,9 +1,10 @@ -from torch import nn, tensor import torch -from ..config import device, seq2int, START_TOKEN, END_TOKEN, PADDING_TOKEN -from ..core.model import Model +from torch import nn from torch.nn import init +from efold.constants import config +from efold.core import model + global_gain = 0.1 @@ -76,9 +77,6 @@ def forward(self, structure): return structure * weights.unsqueeze(-1).unsqueeze(-1) -from torch.nn.functional import multi_head_attention_forward - - class SelfAttention(nn.Module): def __init__( self, @@ -168,11 +166,11 @@ def sequence_batch(batch): out.append( torch.concat( [ - tensor([START_TOKEN], dtype=torch.long).to(device), + torch.tensor([config.tokens.start_token], dtype=torch.long).to(config.device), sequence[:length], - tensor([END_TOKEN], dtype=torch.long).to(device), - tensor([PADDING_TOKEN] * (L - length), dtype=torch.long).to( - device + torch.tensor([config.tokens.end_token], dtype=torch.long).to(config.device), + torch.tensor([config.tokens.padding_token] * (L - length), dtype=torch.long).to( + config.device ), ], ) @@ -182,9 +180,9 @@ def sequence_batch(batch): def structure_batch(batch): structure = batch.get("structure") batch_size, L, _ = structure.shape - embedded_matrix = torch.zeros( - (batch_size, L + 2, L + 2), dtype=torch.float32 - ).to(device) + embedded_matrix = torch.zeros((batch_size, L + 2, L + 2), dtype=torch.float32).to( + config.device + ) embedded_matrix[:, 1:-1, 1:-1] = structure return embedded_matrix @@ -205,7 +203,7 @@ def forward(self, sequence, structure): return sequence, structure -class Ribonanza(Model): +class Ribonanza(model.Model): ntokens = 7 data_type = ["dms", "shape"] diff --git a/efold/models/transformer.py b/efold/models/transformer.py index d2601ae..13449fb 100644 --- a/efold/models/transformer.py +++ b/efold/models/transformer.py @@ -1,17 +1,12 @@ +import numpy as np import torch -from torch import nn, Tensor +from torch import Tensor, nn from torch.nn import TransformerEncoderLayer -import numpy as np -import os -import sys -from ..core.model import Model -from ..core.batch import Batch -dir_name = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(os.path.join(dir_name, "..")) +from efold.core import batch, model -class Transformer(Model): +class Transformer(model.Model): def __init__( self, ntoken: int, @@ -82,9 +77,7 @@ def __init__( assert c_z % 4 == 0, "c_z must be divisible by 4" assert c_z >= 8, "c_z must be greater than 8" self.output_net_structure = nn.Sequential( - ResLayer( - n_blocks=4, dim_in=c_z, dim_out=c_z // 2, kernel_size=3, dropout=dropout - ), + ResLayer(n_blocks=4, dim_in=c_z, dim_out=c_z // 2, kernel_size=3, dropout=dropout), ResLayer( n_blocks=4, dim_in=c_z // 2, @@ -92,12 +85,10 @@ def __init__( kernel_size=3, dropout=dropout, ), - ResLayer( - n_blocks=4, dim_in=c_z // 4, dim_out=1, kernel_size=3, dropout=dropout - ), + ResLayer(n_blocks=4, dim_in=c_z // 4, dim_out=1, kernel_size=3, dropout=dropout), ) - def forward(self, batch: Batch) -> Tensor: + def forward(self, batch: batch.Batch) -> dict[str, Tensor]: """ Args: src: Tensor, shape [seq_len, batch_size] @@ -109,7 +100,7 @@ def forward(self, batch: Batch) -> Tensor: src = self.encoder(src) src = self.pos_encoder(src) - for i, l in enumerate(self.transformer_encoder): + for i, _ in enumerate(self.transformer_encoder): src = self.transformer_encoder[i](src) src = self.resnet(src.unsqueeze(dim=1)).squeeze(dim=1) @@ -120,14 +111,10 @@ def forward(self, batch: Batch) -> Tensor: # Outer concatenation src = self.activ(self.encoder_adapter(src)) matrix = src.unsqueeze(1).repeat(1, src.shape[1], 1, 1) # (N, d_cnn/2, L, L) - matrix = torch.cat( - (matrix, matrix.permute(0, 2, 1, 3)), dim=-1 - ) # (N, d_cnn, L, L) + matrix = torch.cat((matrix, matrix.permute(0, 2, 1, 3)), dim=-1) # (N, d_cnn, L, L) # Resnet layers - pair_prob = self.output_net_structure(matrix.permute(0, 3, 1, 2)).squeeze( - 1 - ) # (N, L, L) + pair_prob = self.output_net_structure(matrix.permute(0, 3, 1, 2)).squeeze(1) # (N, L, L) # Symmetrize structure = (pair_prob + pair_prob.permute(0, 2, 1)) / 2 # (N, L, L) @@ -169,7 +156,7 @@ def __init__(self, n_blocks, dim_in, dim_out, kernel_size, dropout=0.0): # Basic Residula block self.res_layers = [] for i in range(n_blocks): - dilation = pow(2, (i % 3)) + pow(2, (i % 3)) self.res_layers.append( ResBlock( inplanes=dim_in, @@ -183,9 +170,7 @@ def __init__(self, n_blocks, dim_in, dim_out, kernel_size, dropout=0.0): self.res_blocks = nn.Sequential(*self.res_layers) # Adapter to change depth - self.conv_output = nn.Conv2d( - dim_in, dim_out, kernel_size=7, padding=3, bias=True - ) + self.conv_output = nn.Conv2d(dim_in, dim_out, kernel_size=7, padding=3, bias=True) def forward(self, x: Tensor) -> Tensor: x = self.res_blocks(x) @@ -210,9 +195,7 @@ def __init__( self.bn1 = nn.BatchNorm2d(inplanes) self.relu1 = nn.ReLU(inplace=True) - self.conv1 = conv3x3( - inplanes, planes, dilation=dilation1, kernel_size=kernel_size - ) + self.conv1 = conv3x3(inplanes, planes, dilation=dilation1, kernel_size=kernel_size) self.dropout = nn.Dropout(p=dropout) self.relu2 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d( @@ -239,9 +222,7 @@ def forward(self, x: Tensor) -> Tensor: return out -def conv3x3( - in_planes: int, out_planes: int, dilation: int = 1, kernel_size=3 -) -> nn.Conv2d: +def conv3x3(in_planes: int, out_planes: int, dilation: int = 1, kernel_size=3) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d( in_planes, diff --git a/efold/models/unet.py b/efold/models/unet.py index 4f2bd8f..0e9c7b3 100644 --- a/efold/models/unet.py +++ b/efold/models/unet.py @@ -1,62 +1,55 @@ -import torch -from torch import nn, Tensor -import torch.nn.functional as F -from torch.nn import init - -from ..core.model import Model -from ..core.batch import Batch - -from ..config import int2seq +from collections import defaultdict -import os, sys - -from collections import defaultdict +import torch +from torch import Tensor, nn -dir_name = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(os.path.join(dir_name, "..")) +from efold.constants import config +from efold.core import batch, model CH_FOLD2 = 1 -class U_Net(Model): - def __init__(self, img_ch=3, output_ch=1, lr: float = 1e-5, optimizer_fn=torch.optim.Adam, **kwargs): - +class U_Net(model.Model): + def __init__( + self, img_ch=3, output_ch=1, lr: float = 1e-5, optimizer_fn=torch.optim.Adam, **kwargs + ): super().__init__(lr=lr, optimizer_fn=optimizer_fn, **kwargs) self.model_type = "Unet" self.data_type_output = ["structure"] - - self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) - self.Conv1 = conv_block(ch_in=img_ch,ch_out=int(32*CH_FOLD2)) - self.Conv2 = conv_block(ch_in=int(32*CH_FOLD2),ch_out=int(64*CH_FOLD2)) - self.Conv3 = conv_block(ch_in=int(64*CH_FOLD2),ch_out=int(128*CH_FOLD2)) - self.Conv4 = conv_block(ch_in=int(128*CH_FOLD2),ch_out=int(256*CH_FOLD2)) - self.Conv5 = conv_block(ch_in=int(256*CH_FOLD2),ch_out=int(512*CH_FOLD2)) + self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) - self.Up5 = up_conv(ch_in=int(512*CH_FOLD2),ch_out=int(256*CH_FOLD2)) - self.Up_conv5 = conv_block(ch_in=int(512*CH_FOLD2), ch_out=int(256*CH_FOLD2)) + self.Conv1 = conv_block(ch_in=img_ch, ch_out=int(32 * CH_FOLD2)) + self.Conv2 = conv_block(ch_in=int(32 * CH_FOLD2), ch_out=int(64 * CH_FOLD2)) + self.Conv3 = conv_block(ch_in=int(64 * CH_FOLD2), ch_out=int(128 * CH_FOLD2)) + self.Conv4 = conv_block(ch_in=int(128 * CH_FOLD2), ch_out=int(256 * CH_FOLD2)) + self.Conv5 = conv_block(ch_in=int(256 * CH_FOLD2), ch_out=int(512 * CH_FOLD2)) - self.Up4 = up_conv(ch_in=int(256*CH_FOLD2),ch_out=int(128*CH_FOLD2)) - self.Up_conv4 = conv_block(ch_in=int(256*CH_FOLD2), ch_out=int(128*CH_FOLD2)) - - self.Up3 = up_conv(ch_in=int(128*CH_FOLD2),ch_out=int(64*CH_FOLD2)) - self.Up_conv3 = conv_block(ch_in=int(128*CH_FOLD2), ch_out=int(64*CH_FOLD2)) - - self.Up2 = up_conv(ch_in=int(64*CH_FOLD2),ch_out=int(32*CH_FOLD2)) - self.Up_conv2 = conv_block(ch_in=int(64*CH_FOLD2), ch_out=int(32*CH_FOLD2)) + self.Up5 = up_conv(ch_in=int(512 * CH_FOLD2), ch_out=int(256 * CH_FOLD2)) + self.Up_conv5 = conv_block(ch_in=int(512 * CH_FOLD2), ch_out=int(256 * CH_FOLD2)) - self.Conv_1x1 = nn.Conv2d(int(32*CH_FOLD2),output_ch,kernel_size=1,stride=1,padding=0) + self.Up4 = up_conv(ch_in=int(256 * CH_FOLD2), ch_out=int(128 * CH_FOLD2)) + self.Up_conv4 = conv_block(ch_in=int(256 * CH_FOLD2), ch_out=int(128 * CH_FOLD2)) + self.Up3 = up_conv(ch_in=int(128 * CH_FOLD2), ch_out=int(64 * CH_FOLD2)) + self.Up_conv3 = conv_block(ch_in=int(128 * CH_FOLD2), ch_out=int(64 * CH_FOLD2)) - def forward(self, batch: Batch) -> Tensor: + self.Up2 = up_conv(ch_in=int(64 * CH_FOLD2), ch_out=int(32 * CH_FOLD2)) + self.Up_conv2 = conv_block(ch_in=int(64 * CH_FOLD2), ch_out=int(32 * CH_FOLD2)) + self.Conv_1x1 = nn.Conv2d(int(32 * CH_FOLD2), output_ch, kernel_size=1, stride=1, padding=0) + + def forward(self, batch: batch.Batch) -> dict[str, Tensor]: src = batch.get("sequence") padd_multiple = 32 pad_len = 0 - if src.shape[1]%padd_multiple: - pad_len = (padd_multiple-src.shape[1]%padd_multiple) - src = torch.cat( (src, torch.zeros((src.shape[0], pad_len), device=self.device, dtype=torch.long) ), dim=-1) + if src.shape[1] % padd_multiple: + pad_len = padd_multiple - src.shape[1] % padd_multiple + src = torch.cat( + (src, torch.zeros((src.shape[0], pad_len), device=self.device, dtype=torch.long)), + dim=-1, + ) # def get_cut_len(data_len,set_len): # l = data_len @@ -67,7 +60,8 @@ def forward(self, batch: Batch) -> Tensor: # return l # pad_len = get_cut_len(src.shape[1], 80)-src.shape[1] - # src = torch.cat( (src, torch.zeros((src.shape[0], pad_len), device=self.device, dtype=torch.long) ), dim=-1) + # src = torch.cat( (src, torch.zeros((src.shape[0], pad_len), device=self.device, + # dtype=torch.long) ), dim=-1) x = self.seq2map(src) @@ -78,7 +72,7 @@ def forward(self, batch: Batch) -> Tensor: x2 = self.Maxpool(x1) x2 = self.Conv2(x2) - + x3 = self.Maxpool(x2) x3 = self.Conv3(x3) @@ -90,20 +84,20 @@ def forward(self, batch: Batch) -> Tensor: # decoding + concat path d5 = self.Up5(x5) - d5 = torch.cat((x4,d5),dim=1) - + d5 = torch.cat((x4, d5), dim=1) + d5 = self.Up_conv5(d5) - + d4 = self.Up4(d5) - d4 = torch.cat((x3,d4),dim=1) + d4 = torch.cat((x3, d4), dim=1) d4 = self.Up_conv4(d4) d3 = self.Up3(d4) - d3 = torch.cat((x2,d3),dim=1) + d3 = torch.cat((x2, d3), dim=1) d3 = self.Up_conv3(d3) d2 = self.Up2(d3) - d2 = torch.cat((x1,d2),dim=1) + d2 = torch.cat((x1, d2), dim=1) d2 = self.Up_conv2(d2) d1 = self.Conv_1x1(d2) @@ -111,53 +105,65 @@ def forward(self, batch: Batch) -> Tensor: structure = torch.transpose(d1, -1, -2) * d1 - return { - "structure": structure[:, :src.shape[1]-pad_len, :src.shape[1]-pad_len] - } - + return {"structure": structure[:, : src.shape[1] - pad_len, : src.shape[1] - pad_len]} def seq2map(self, seq_int): - # take integer encoded sequence and return last channel of embedding (pairing energy) def creatmat(data, device=None): - with torch.no_grad(): - data = ''.join([int2seq[d] for d in data.tolist()]) - paired = defaultdict(float, {'AU':2., 'UA':2., 'GC':3., 'CG':3., 'UG':0.8, 'GU':0.8}) + data = "".join([config.tokens.int2seq[d] for d in data.tolist()]) + paired = defaultdict( + float, {"AU": 2.0, "UA": 2.0, "GC": 3.0, "CG": 3.0, "UG": 0.8, "GU": 0.8} + ) - mat = torch.tensor([[paired[x+y] for y in data] for x in data]).to(device) + mat = torch.tensor([[paired[x + y] for y in data] for x in data]).to(device) n = len(data) - i, j = torch.meshgrid(torch.arange(n).to(device), torch.arange(n).to(device), indexing='ij') + i, j = torch.meshgrid( + torch.arange(n).to(device), torch.arange(n).to(device), indexing="ij" + ) t = torch.arange(30).to(device) - m1 = torch.where((i[:, :, None] - t >= 0) & (j[:, :, None] + t < n), mat[torch.clamp(i[:,:,None]-t, 0, n-1), torch.clamp(j[:,:,None]+t, 0, n-1)], 0) - m1 *= torch.exp(-0.5*t*t) + m1 = torch.where( + (i[:, :, None] - t >= 0) & (j[:, :, None] + t < n), + mat[ + torch.clamp(i[:, :, None] - t, 0, n - 1), + torch.clamp(j[:, :, None] + t, 0, n - 1), + ], + 0, + ) + m1 *= torch.exp(-0.5 * t * t) m1_0pad = torch.nn.functional.pad(m1, (0, 1)) - first0 = torch.argmax((m1_0pad==0).to(int), dim=2) - to0indices = t[None,None,:]>first0[:,:,None] + first0 = torch.argmax((m1_0pad == 0).to(int), dim=2) + to0indices = t[None, None, :] > first0[:, :, None] m1[to0indices] = 0 m1 = m1.sum(dim=2) t = torch.arange(1, 30).to(device) - m2 = torch.where((i[:, :, None] + t < n) & (j[:, :, None] - t >= 0), mat[torch.clamp(i[:,:,None]+t, 0, n-1), torch.clamp(j[:,:,None]-t, 0, n-1)], 0) - m2 *= torch.exp(-0.5*t*t) + m2 = torch.where( + (i[:, :, None] + t < n) & (j[:, :, None] - t >= 0), + mat[ + torch.clamp(i[:, :, None] + t, 0, n - 1), + torch.clamp(j[:, :, None] - t, 0, n - 1), + ], + 0, + ) + m2 *= torch.exp(-0.5 * t * t) m2_0pad = torch.nn.functional.pad(m2, (0, 1)) - first0 = torch.argmax((m2_0pad==0).to(int), dim=2) - to0indices = torch.arange(29).to(device)[None,None,:]>first0[:,:,None] + first0 = torch.argmax((m2_0pad == 0).to(int), dim=2) + to0indices = torch.arange(29).to(device)[None, None, :] > first0[:, :, None] m2[to0indices] = 0 m2 = m2.sum(dim=2) - m2[m1==0] = 0 + m2[m1 == 0] = 0 - return (m1+m2).to(self.device) + return (m1 + m2).to(self.device) # Assemble all data full_map = [] one_hot_embed = torch.zeros((5, 4), device=self.device) one_hot_embed[1:] = torch.eye(4) for seq in seq_int: - seq_hot = one_hot_embed[seq].type(torch.long) pair_map = torch.kron(seq_hot, seq_hot).reshape(len(seq), len(seq), 16) @@ -165,46 +171,41 @@ def creatmat(data, device=None): full_map.append(torch.cat((pair_map, energy_map.unsqueeze(-1)), dim=-1)) - return torch.stack(full_map).permute(0, 3, 1, 2).contiguous() - class conv_block(nn.Module): - def __init__(self,ch_in,ch_out): - super(conv_block,self).__init__() + def __init__(self, ch_in, ch_out): + super(conv_block, self).__init__() self.conv = nn.Sequential( - nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), + nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True), - nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True), + nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), nn.BatchNorm2d(ch_out), - nn.ReLU(inplace=True) + nn.ReLU(inplace=True), ) - - def forward(self,x): + def forward(self, x): x = self.conv(x) return x + class up_conv(nn.Module): - def __init__(self,ch_in,ch_out): - super(up_conv,self).__init__() + def __init__(self, ch_in, ch_out): + super(up_conv, self).__init__() self.up = nn.Sequential( nn.Upsample(scale_factor=2), - nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True), - nn.BatchNorm2d(ch_out), - nn.ReLU(inplace=True) + nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), + nn.BatchNorm2d(ch_out), + nn.ReLU(inplace=True), ) - def forward(self,x): + def forward(self, x): x = self.up(x) return x - - - # class U_Net_FP(nn.Module): # def __init__(self,img_ch=17,output_ch=1): # super(U_Net_FP, self).__init__() @@ -277,4 +278,4 @@ def forward(self,x): # d1 = self.Conv_1x1(d2) # d1 = d1.squeeze(1) -# return torch.transpose(d1, -1, -2) * d1 \ No newline at end of file +# return torch.transpose(d1, -1, -2) * d1 diff --git a/efold/util/__init__.py b/efold/util/__init__.py index f64f073..e69de29 100644 --- a/efold/util/__init__.py +++ b/efold/util/__init__.py @@ -1,21 +0,0 @@ -from torch.optim import Adam -from torch import nn - -str2fun = { - "adam": Adam, - "Adam": Adam, - "mse_loss": nn.functional.mse_loss, - "l1_loss": nn.functional.l1_loss, -} - - -def unzip(f): - def wrapper(*args, **kwargs): - out = f(*args, **kwargs) - if not hasattr(out, "__iter__"): - return out - if len(out) == 1: - return out[0] - return out - - return wrapper diff --git a/efold/util/format_conversion.py b/efold/util/format_conversion.py index f1f46de..8bbb61c 100644 --- a/efold/util/format_conversion.py +++ b/efold/util/format_conversion.py @@ -1,11 +1,10 @@ # this code wsa taken from arnie_utils.py +from typing import Optional -def convert_bp_list_to_dotbracket(bp_list, seq_len): - - # convert the bp list from 1-indexed to 0-indexed - bp_list = [(b-1, c-1) for b, c in bp_list] - +def convert_bp_list_to_dotbracket(bp_list: list[tuple[int, int]], seq_len: int) -> Optional[str]: + bp_list = [(b - 1, c - 1) for b, c in bp_list] + db = "." * seq_len # group into bps that are not intertwined and can use same brackets! groups = _group_into_non_conflicting_bp(bp_list) @@ -13,8 +12,9 @@ def convert_bp_list_to_dotbracket(bp_list, seq_len): # all bp that are not intertwined get (), but all others are # groups to be nonconflicting and then asigned (), [], {}, <> by group chars_set = [("(", ")"), ("(", ")"), ("[", "]"), ("{", "}"), ("<", ">")] - alphabet = [(chr(lower), chr(upper)) - for upper, lower in zip(list(range(65, 91)), list(range(97, 123)))] + alphabet = [ + (chr(lower), chr(upper)) for upper, lower in zip(list(range(65, 91)), list(range(97, 123))) + ] chars_set.extend(alphabet) if len(groups) > len(chars_set): @@ -23,25 +23,23 @@ def convert_bp_list_to_dotbracket(bp_list, seq_len): for group, chars in zip(groups, chars_set): for bp in group: - db = db[:bp[0]] + chars[0] + \ - db[bp[0] + 1:bp[1]] + chars[1] + db[bp[1] + 1:] + db = db[: bp[0]] + chars[0] + db[bp[0] + 1 : bp[1]] + chars[1] + db[bp[1] + 1 :] return db -def _group_into_non_conflicting_bp(bp_list): - ''' given a bp_list, group basepairs into groups that do not conflict +def _group_into_non_conflicting_bp(bp_list: list[tuple[int, int]]) -> list[list[tuple[int, int]]]: + """given a bp_list, group basepairs into groups that do not conflict Args bp_list: list of base_pairs Returns: groups of baspairs that are not intertwined - ''' + """ conflict_list = _get_list_bp_conflicts(bp_list) non_redudant_bp_list = _get_non_redudant_bp_list(conflict_list) - bp_with_no_conflict = [ - bp for bp in bp_list if bp not in non_redudant_bp_list] + bp_with_no_conflict = [bp for bp in bp_list if bp not in non_redudant_bp_list] groups = [bp_with_no_conflict] while non_redudant_bp_list != []: current_bp = non_redudant_bp_list[0] @@ -51,8 +49,7 @@ def _group_into_non_conflicting_bp(bp_list): current_bp_conflicts.append(conflict[1]) elif current_bp == conflict[1]: current_bp_conflicts.append(conflict[0]) - max_group = [ - bp for bp in non_redudant_bp_list if bp not in current_bp_conflicts] + max_group = [bp for bp in non_redudant_bp_list if bp not in current_bp_conflicts] to_remove = [] for i, bpA in enumerate(max_group): for bpB in max_group[i:]: @@ -62,19 +59,22 @@ def _group_into_non_conflicting_bp(bp_list): group = [bp for bp in max_group if bp not in to_remove] groups.append(group) non_redudant_bp_list = current_bp_conflicts - conflict_list = [conflict for conflict in conflict_list if conflict[0] - not in group and conflict[1] not in group] + conflict_list = [ + conflict + for conflict in conflict_list + if conflict[0] not in group and conflict[1] not in group + ] return groups -def _get_non_redudant_bp_list(conflict_list): - ''' given a conflict list get the list of nonredundant basepairs this list has +def _get_non_redudant_bp_list(conflict_list: list) -> list: + """given a conflict list get the list of nonredundant basepairs this list has Args: conflict_list: list of pairs of base_pairs that are intertwined basepairs returns: list of basepairs in conflict list without repeats - ''' + """ non_redudant_bp_list = [] for conflict in conflict_list: if conflict[0] not in non_redudant_bp_list: @@ -84,20 +84,21 @@ def _get_non_redudant_bp_list(conflict_list): return non_redudant_bp_list - -def _get_list_bp_conflicts(bp_list): - '''given a bp_list gives the list of conflicts bp-s which indicate PK structure +def _get_list_bp_conflicts(bp_list: list[tuple[int, int]]) -> list: + """given a bp_list gives the list of conflicts bp-s which indicate PK structure Args: - bp_list: of list of base pairs where the base pairs are list of indeces of the bp in increasing order (bp[0] threshold).float()\n", "\n", - "\n", - " TP = torch.sum(pred_matrix*target_matrix)\n", + " TP = torch.sum(pred_matrix * target_matrix)\n", " PP = torch.sum(pred_matrix)\n", " P = torch.sum(target_matrix)\n", " sum_pair = PP + P\n", @@ -102,11 +102,8 @@ " if sum_pair == 0:\n", " return [1.0, 1.0, 1.0]\n", " else:\n", - " return [\n", - " (TP / PP).item(),\n", - " (TP / P).item(),\n", - " (2 * TP / sum_pair).item()\n", - " ]\n", + " return [(TP / PP).item(), (TP / P).item(), (2 * TP / sum_pair).item()]\n", + "\n", "\n", "def compute_confusion_matrix(label, pred):\n", " true_negatives = (1 - label) * (1 - pred)\n", @@ -114,7 +111,9 @@ " false_positives = (1 - label) * pred\n", " false_negatives = label * (1 - pred)\n", " confusion_matrix = true_positives + false_positives * 2 + false_negatives * 3\n", - " assert ((true_negatives == 1) == (confusion_matrix == 0)).all(), \"True negatives are not correctly computed\"\n", + " assert ((true_negatives == 1) == (confusion_matrix == 0)).all(), (\n", + " \"True negatives are not correctly computed\"\n", + " )\n", " return confusion_matrix" ] }, @@ -124,10 +123,14 @@ "metadata": {}, "outputs": [], "source": [ - "ground_truth['non_canonical'] = ground_truth.apply(lambda x: ratio_nonCanonical(x['sequence'], x['structure']), axis=1)\n", - "ground_truth['sharp_loops'] = ground_truth.apply(lambda x: ratio_sharpLoops(x['structure']), axis=1)\n", - "ground_truth[ground_truth['sharp_loops'] >0]\n", - "ground_truth['pairing_matrix'] = ground_truth.apply(lambda x: ListofPairs2pairMatrix(np.array(x['structure']), len(x['sequence'])), axis=1)" + "ground_truth[\"non_canonical\"] = ground_truth.apply(\n", + " lambda x: ratio_nonCanonical(x[\"sequence\"], x[\"structure\"]), axis=1\n", + ")\n", + "ground_truth[\"sharp_loops\"] = ground_truth.apply(lambda x: ratio_sharpLoops(x[\"structure\"]), axis=1)\n", + "ground_truth[ground_truth[\"sharp_loops\"] > 0]\n", + "ground_truth[\"pairing_matrix\"] = ground_truth.apply(\n", + " lambda x: ListofPairs2pairMatrix(np.array(x[\"structure\"]), len(x[\"sequence\"])), axis=1\n", + ")" ] }, { @@ -137,8 +140,10 @@ "outputs": [], "source": [ "import time\n", + "\n", "from tqdm import tqdm\n", - "from efold import inference\n", + "\n", + "from efold.api.run import run\n", "\n", "eFold_processed = pd.DataFrame()\n", "\n", @@ -154,15 +159,17 @@ " dTs = []\n", "\n", " for idx, row in tqdm(ground_truth.iterrows(), total=len(ground_truth)):\n", - " true_structure = torch.tensor(row['structure'])\n", - " sequence = row['sequence']\n", + " true_structure = torch.tensor(row[\"structure\"])\n", + " sequence = row[\"sequence\"]\n", "\n", " t0 = time.time()\n", - " prediction = torch.tensor(inference(sequence, fmt='bp')[sequence])-1\n", + " prediction = torch.tensor(run.run(sequence, fmt=\"bp\")[sequence]) - 1\n", " dT = time.time() - t0\n", "\n", - " precision, recall, f1 = compute_f1(ListofPairs2pairMatrix(prediction, len(sequence)), \n", - " ListofPairs2pairMatrix(true_structure, len(sequence)))\n", + " precision, recall, f1 = compute_f1(\n", + " ListofPairs2pairMatrix(prediction, len(sequence)),\n", + " ListofPairs2pairMatrix(true_structure, len(sequence)),\n", + " )\n", "\n", " Precisions.append(precision)\n", " Recalls.append(recall)\n", @@ -170,12 +177,31 @@ " predictions.append(prediction)\n", " dTs.append(dT)\n", "\n", - " eFold_processed = pd.concat([eFold_processed, pd.DataFrame({'reference': ground_truth.index, 'sequence': ground_truth['sequence'],\n", - " 'threshold': threshold,\n", - " 'precision': Precisions, 'recall': Recalls, 'f1': F1s, 'dT': dTs,\n", - " 'structure': predictions})], axis=0)\n", + " eFold_processed = pd.concat(\n", + " [\n", + " eFold_processed,\n", + " pd.DataFrame(\n", + " {\n", + " \"reference\": ground_truth.index,\n", + " \"sequence\": ground_truth[\"sequence\"],\n", + " \"threshold\": threshold,\n", + " \"precision\": Precisions,\n", + " \"recall\": Recalls,\n", + " \"f1\": F1s,\n", + " \"dT\": dTs,\n", + " \"structure\": predictions,\n", + " }\n", + " ),\n", + " ],\n", + " axis=0,\n", + " )\n", "# Add dataset name\n", - "eFold_processed = eFold_processed.merge(ground_truth.reset_index().rename(columns={'index':'reference'})[['reference', 'sequence', 'dataset']], on=['reference', 'sequence'])\n", + "eFold_processed = eFold_processed.merge(\n", + " ground_truth.reset_index().rename(columns={\"index\": \"reference\"})[\n", + " [\"reference\", \"sequence\", \"dataset\"]\n", + " ],\n", + " on=[\"reference\", \"sequence\"],\n", + ")\n", "# eFold_processed['basePairs'] = eFold_processed['structure'].apply(lambda x: torch.unique(torch.sort(torch.stack(torch.where(x>0)).T, dim=1)[0], dim=0))" ] }, @@ -185,13 +211,21 @@ "metadata": {}, "outputs": [], "source": [ - "eFold_processed['non_canonical'] = eFold_processed.apply(lambda x: ratio_nonCanonical(x['sequence'], x['structure']), axis=1)\n", - "eFold_processed['sharp_loops'] = eFold_processed.apply(lambda x: ratio_sharpLoops(x['structure']), axis=1)\n", - "eFold_processed['pairing_matrix'] = eFold_processed.apply(lambda x: ListofPairs2pairMatrix(x['structure'], len(x['sequence'])), axis=1)\n", - "eFold_processed['multiPairs'] = eFold_processed['pairing_matrix'].apply(lambda x: (x.sum(axis=0) >1).sum().item()/len(x) )\n", - "eFold_processed['length'] = eFold_processed['sequence'].apply(len)\n", - "\n", - "eFold_processed.groupby('threshold')[['non_canonical', 'sharp_loops', 'multiPairs']].mean()" + "eFold_processed[\"non_canonical\"] = eFold_processed.apply(\n", + " lambda x: ratio_nonCanonical(x[\"sequence\"], x[\"structure\"]), axis=1\n", + ")\n", + "eFold_processed[\"sharp_loops\"] = eFold_processed.apply(\n", + " lambda x: ratio_sharpLoops(x[\"structure\"]), axis=1\n", + ")\n", + "eFold_processed[\"pairing_matrix\"] = eFold_processed.apply(\n", + " lambda x: ListofPairs2pairMatrix(x[\"structure\"], len(x[\"sequence\"])), axis=1\n", + ")\n", + "eFold_processed[\"multiPairs\"] = eFold_processed[\"pairing_matrix\"].apply(\n", + " lambda x: (x.sum(axis=0) > 1).sum().item() / len(x)\n", + ")\n", + "eFold_processed[\"length\"] = eFold_processed[\"sequence\"].apply(len)\n", + "\n", + "eFold_processed.groupby(\"threshold\")[[\"non_canonical\", \"sharp_loops\", \"multiPairs\"]].mean()" ] }, { @@ -201,10 +235,12 @@ "outputs": [], "source": [ "# Group the data by model and dataset and calculate the mean for each group\n", - "grouped = eFold_processed.groupby(['threshold', 'dataset']).mean(numeric_only=True).reset_index()\n", + "grouped = eFold_processed.groupby([\"threshold\", \"dataset\"]).mean(numeric_only=True).reset_index()\n", "\n", "# Pivot the table to create a multi-level column structure\n", - "pivot_df = pd.pivot_table(grouped, index='threshold', columns='dataset', values=['precision', 'recall', 'f1'])\n", + "pivot_df = pd.pivot_table(\n", + " grouped, index=\"threshold\", columns=\"dataset\", values=[\"precision\", \"recall\", \"f1\"]\n", + ")\n", "\n", "# Swap the level of the columns to have dataset as the top level and the metrics as the second level\n", "pivot_df = pivot_df.swaplevel(i=0, j=1, axis=1).sort_index(axis=1)\n", @@ -213,16 +249,21 @@ "# new_order = ['SimpleThreshold', 'HungarianAlgorithm', 'UFold_processing', 'OptimalProcessing']\n", "# pivot_df = pivot_df.reindex(new_order)\n", "\n", - "pivot_df = pivot_df.reindex(columns=pivot_df.columns.reindex(['precision', 'recall', 'f1'], level=1)[0])[['PDB', 'archiveII_blast', 'viral_fragments', 'lncRNA_nonFiltered']]\n", - "\n", - "pivot_df = pivot_df.style\\\n", - " .format(precision=3)\\\n", - " .highlight_max(axis=0, props=\"font-weight:bold;font-color:black;\")\\\n", - " .background_gradient(axis=1, vmin=-0.1, vmax=1, cmap=\"viridis\", text_color_threshold=0)\\\n", - " .set_properties(**{'text-align': 'center'})\\\n", - " .set_table_styles(\n", - " [{\"selector\": \"th\", \"props\": [('text-align', 'center')]},\n", - " ])\n", + "pivot_df = pivot_df.reindex(\n", + " columns=pivot_df.columns.reindex([\"precision\", \"recall\", \"f1\"], level=1)[0]\n", + ")[[\"PDB\", \"archiveII_blast\", \"viral_fragments\", \"lncRNA_nonFiltered\"]]\n", + "\n", + "pivot_df = (\n", + " pivot_df.style.format(precision=3)\n", + " .highlight_max(axis=0, props=\"font-weight:bold;font-color:black;\")\n", + " .background_gradient(axis=1, vmin=-0.1, vmax=1, cmap=\"viridis\", text_color_threshold=0)\n", + " .set_properties(**{\"text-align\": \"center\"})\n", + " .set_table_styles(\n", + " [\n", + " {\"selector\": \"th\", \"props\": [(\"text-align\", \"center\")]},\n", + " ]\n", + " )\n", + ")\n", "pivot_df" ] } diff --git a/tests/test_speed_eFold.py b/tests/test_speed_eFold.py index 6e9d9f2..c29ced8 100644 --- a/tests/test_speed_eFold.py +++ b/tests/test_speed_eFold.py @@ -1,19 +1,14 @@ -import sys, os -file_dir = os.path.dirname(os.path.realpath(__file__)) -sys.path.append(os.path.join(file_dir, '../../eFold')) +import os +import time -from efold import inference -import pandas as pd import numpy as np -from rouskinhf import get_dataset -import torch - -import time +from rnastructure_wrapper import RNAstructure from tqdm import tqdm +import plotly.graph_objects as go -from rnastructure_wrapper import RNAstructure +from efold.api import run as run_module -Fold = RNAstructure(path='/root/RNAstructure/exe/') +Fold = RNAstructure(path="/root/RNAstructure/exe/") rnaStructure_dTs = [] efold_GPU_dTs = [] @@ -21,38 +16,40 @@ lengths = np.linspace(10, 500, 100).astype(int) for length in tqdm(lengths): - - sequence_random = ''.join(np.random.choice(['A', 'C', 'G', 'U'], length)) + sequence_random = "".join(np.random.choice(["A", "C", "G", "U"], length)) # eFold GPU t0 = time.time() - inference(sequence_random, fmt='bp') - dT = time.time()-t0 + run_module.run(sequence_random, fmt="bp") + dT = time.time() - t0 efold_GPU_dTs.append(dT) # RNAfold t0 = time.time() - inference(sequence_random, fmt='bp', device='cpu') - dT = time.time()-t0 + run_module.run(sequence_random, fmt="bp", device="cpu") + dT = time.time() - t0 efold_CPU_dTs.append(dT) # RNAstructure no MFE t0 = time.time() Fold.fold(sequence_random, mfe_only=False) - dT = time.time()-t0 + dT = time.time() - t0 rnaStructure_dTs.append(dT) -import plotly.graph_objects as go - fig = go.Figure() -fig.add_trace(go.Scatter(x=lengths, y=rnaStructure_dTs, mode='lines', name='RNAstructure')) -fig.add_trace(go.Scatter(x=lengths, y=efold_CPU_dTs, mode='lines', name='eFold (CPU)')) -fig.add_trace(go.Scatter(x=lengths, y=efold_GPU_dTs, mode='lines', name='eFold (GPU)')) -fig.update_layout(title='Inference time of eFold vs RNAstructure', - xaxis_title='Sequence length', yaxis_title='Time [s]', - template='plotly_white', font_size=22, margin=dict(l=100, r=20, t=100, b=100), - width=2000, height=1200) +fig.add_trace(go.Scatter(x=lengths, y=rnaStructure_dTs, mode="lines", name="RNAstructure")) +fig.add_trace(go.Scatter(x=lengths, y=efold_CPU_dTs, mode="lines", name="eFold (CPU)")) +fig.add_trace(go.Scatter(x=lengths, y=efold_GPU_dTs, mode="lines", name="eFold (GPU)")) +fig.update_layout( + title="Inference time of eFold vs RNAstructure", + xaxis_title="Sequence length", + yaxis_title="Time [s]", + template="plotly_white", + font_size=22, + margin=dict(l=100, r=20, t=100, b=100), + width=2000, + height=1200, +) # fig.show() - -fig.write_image(os.path.join(file_dir, 'speed_comparison.jpg')) - +file_dir = os.path.dirname(os.path.realpath(__file__)) +fig.write_image(os.path.join(file_dir, "speed_comparison.jpg"))