Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ efold -h
### Using python

```python
>>> from efold import inference
>>> inference('AAACAUGAGGAUUACCCAUGU', fmt='dotbracket')
>>> from efold.api import run
>>> run.run('AAACAUGAGGAUUACCCAUGU', fmt='dotbracket')
{'AAACAUGAGGAUUACCCAUGU': '..(((((.((....)))))))'}
```

Expand Down
5 changes: 0 additions & 5 deletions efold/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +0,0 @@
from .models import create_model
from .core import *
from .util import *
from .config import *
from .api import *
1 change: 0 additions & 1 deletion efold/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from .run import run as inference
77 changes: 45 additions & 32 deletions efold/api/run.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import os
from typing import List, Union
from ..models import create_model
import torch
from os.path import join, dirname
from ..core import batch
from ..core.embeddings import sequence_to_int
from ..core.postprocess import Postprocess
from os.path import dirname, join
from typing import List, Optional, Union

import numpy as np
from ..util.format_conversion import convert_bp_list_to_dotbracket
import torch

from efold.core import batch, embeddings, postprocess
from efold.models import factory
from efold.util import format_conversion

# Import-time side effect: make float32 torch's default floating-point dtype.
torch.set_default_dtype(torch.float32)

# Module-level Postprocess instance; _predict_structure calls its run() on the
# model's "structure" output.
postprocesser = Postprocess()
postprocesser = postprocess.Postprocess()

def _load_sequences_from_fasta(fasta:str):

def _load_sequences_from_fasta(fasta: str) -> list[str]:
with open(fasta, "r") as f:
lines = f.readlines()
sequences = []
Expand All @@ -24,36 +25,44 @@ def _load_sequences_from_fasta(fasta:str):
sequences[-1] += line.strip()
return sequences

# Predict the base pairs of one sequence: encode it, wrap it in a one-item
# Batch, run the model, post-process the "structure" output and return the
# paired positions as 1-indexed (i, j) tuples.
# In this diff view each removed line is directly followed by its replacement.
def _predict_structure(model, sequence:str, device='cpu'):

seq = sequence_to_int(sequence).unsqueeze(0)
def _predict_structure(model, sequence: str, device: str = "cpu") -> list[tuple[int, int]]:
seq = embeddings.sequence_to_int(sequence).unsqueeze(0)
b = batch.Batch(
sequence=seq,
reference=[""],
# NOTE(review): seq already carries the leading axis added by unsqueeze(0), so
# len(seq) presumably evaluates to 1 rather than the number of bases — confirm
# whether len(sequence) was intended for length and L.
length=[len(seq)],
L = len(seq),
L=len(seq),
use_error=False,
batch_size=1,
data_types=["sequence"],
dt_count={"sequence": 1}).to(device)

dt_count={"sequence": 1},
).to(device)

# predict the structure
# inference_mode: forward pass only, autograd tracking is disabled.
with torch.inference_mode():
pred = model(b)
# Post-processing runs on CPU copies of the prediction and the input sequence;
# the result is rounded and entry 0 is kept (batch_size is 1).
structure = postprocesser.run(pred['structure'].to('cpu'), b.get('sequence').to('cpu')).numpy().round()[0]
structure = (
postprocesser.run(pred["structure"].to("cpu"), b.get("sequence").to("cpu"))
.numpy()
.round()[0]
)

# turn into 1-indexed base pairs
# np.triu keeps only the upper triangle (diagonal included), so each pair is
# listed once with first index <= second index; "+ 1" shifts 0-based indices to
# 1-based. The tuple elements are NumPy integer scalars. The comprehension's
# own b is local to it and does not rebind the Batch named b above.
return [(b,c) for b, c in (np.stack(np.where(np.triu(structure) == 1)) + 1).T]
return [(b, c) for b, c in (np.stack(np.where(np.triu(structure) == 1)) + 1).T]


def run(arg:Union[str, List[str]]=None, fmt="dotbracket", device=None):
def run(
arg: Optional[Union[str, List[str]]] = None, fmt: str = "dotbracket", device: Optional[str] = None
) -> dict[str, Union[str, list[tuple[int, int]]]]:
"""Runs the Efold API on the provided sequence or fasta file.

Args:
arg (str): The sequence or the list of sequences to run Efold on, or the path to a fasta file containing the sequences.
arg (str): The sequence or the list of sequences to run Efold on, or the path to a fasta file containing the sequences.

Returns:
dict: A dictionary containing the sequences as keys and the predicted secondary structures as values.

Examples:
>>> from efold.api.run import run
>>> structure = run("GGGAAAUCC") # this is awful, we need to remove the prints
Expand All @@ -66,9 +75,11 @@ def run(arg:Union[str, List[str]]=None, fmt="dotbracket", device=None):
No scaling, use preLN
Replace GLU with swish for Conv
>>> assert structure == {'GGGAAAUCC': [(1, 9), (2, 8)]}, "Test failed: {}".format(structure)

"""
assert fmt in ["dotbracket", "basepair", 'bp'], "Invalid format. Must be either 'dotbracket' or 'basepair'"
assert fmt in ["dotbracket", "basepair", "bp"], (
"Invalid format. Must be either 'dotbracket' or 'basepair'"
)

# Check if the input is valid
if not arg:
Expand All @@ -77,13 +88,13 @@ def run(arg:Union[str, List[str]]=None, fmt="dotbracket", device=None):
if not os.path.exists(arg):
raise ValueError("File not found")
sequences = _load_sequences_from_fasta(arg)
elif type(arg) == str:
elif isinstance(arg, str):
sequences = [arg]
elif hasattr(arg, "__iter__") and all([isinstance(s, str) for s in arg]):
sequences = arg
else:
raise ValueError("Either sequence or fasta must be provided")

# Get device
if not device:
if torch.cuda.is_available():
Expand All @@ -92,7 +103,7 @@ def run(arg:Union[str, List[str]]=None, fmt="dotbracket", device=None):
device = torch.device("cpu")

# Load best model
model = create_model(
model = factory.create_model(
model="efold",
ntoken=5,
d_model=64,
Expand All @@ -105,17 +116,19 @@ def run(arg:Union[str, List[str]]=None, fmt="dotbracket", device=None):
weight_decay=0,
gamma=0.995,
)
model.load_state_dict(torch.load(join(dirname(dirname(__file__)), "resources/efold_weights.pt")), strict=False)
model.load_state_dict(
torch.load(join(dirname(dirname(__file__)), "resources/efold_weights.pt")), strict=False
)
model.eval()
model = model.to(device)

structures = []
for seq in sequences:
for seq in sequences:
structure = _predict_structure(model, seq, device=device)
if fmt == "dotbracket":
db_structure = convert_bp_list_to_dotbracket(structure, len(seq))
if db_structure != None:
db_structure = format_conversion.convert_bp_list_to_dotbracket(structure, len(seq))
if db_structure is not None:
structure = db_structure
structures.append(structure)

return {seq: structure for seq, structure in zip(sequences, structures)}
return {seq: structure for seq, structure in zip(sequences, structures)}
50 changes: 30 additions & 20 deletions efold/cli.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,43 @@
import json

import click
from efold.api.run import run

@click.command('efold')
@click.argument('sequence', required=False, type=str)
@click.option('--fasta', '-f', help='Input FASTA file path')
@click.option('--output', '-o', default='output.txt', help='Output file path (json, txt or csv)', type=click.Path())
@click.option('--basepair/--dotbracket', '-bp/-db', default=False, help='Output structure format')
@click.option('--help', '-h', is_flag=True, help='Show this message', type=bool)
def cli(sequence, fasta, output, basepair, help):


from efold.api import run


@click.command("efold")
@click.argument("sequence", required=False, type=str)
@click.option("--fasta", "-f", help="Input FASTA file path")
@click.option(
"--output",
"-o",
default="output.txt",
help="Output file path (json, txt or csv)",
type=click.Path(),
)
@click.option("--basepair/--dotbracket", "-bp/-db", default=False, help="Output structure format")
@click.option("--help", "-h", is_flag=True, help="Show this message", type=bool)
def cli(sequence: str, fasta: str, output: str, basepair: bool, help: bool) -> None:
if help:
click.echo(cli.get_help(click.Context(cli)))
return
fmt = 'bp' if basepair else 'dotbracket'

fmt = "bp" if basepair else "dotbracket"
if sequence:
result = run(sequence, fmt)
result = run.run(sequence, fmt)
elif fasta:
result = run(fasta, fmt)
result = run.run(fasta, fmt)
else:
click.echo("Please provide either a sequence or a FASTA file.")
return

with open(output, 'w') as f:
file_fmt = output.split('.')[-1]
if file_fmt == 'json':
with open(output, "w") as f:
file_fmt = output.split(".")[-1]
if file_fmt == "json":
f.write(json.dumps(result, indent=4))
elif file_fmt == 'csv':
elif file_fmt == "csv":
import csv

writer = csv.writer(f)
writer.writerows(result.items())
else:
Expand All @@ -40,5 +49,6 @@ def cli(sequence, fasta, output, basepair, help):
click.echo()
click.echo(f"Output saved to {output}")

if __name__ == '__main__':
cli()

if __name__ == "__main__":
cli()
55 changes: 0 additions & 55 deletions efold/config.py

This file was deleted.

Loading