Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: Run Tests

on:
push:
branches: ["main"]
pull_request:
branches: ["main"]

jobs:
test:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- name: Set up environment
uses: actions/setup-python@v5
with:
python-version: "3.10"

- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh

- name: Install dependencies
run: |
uv sync --group dev --extra cpu

- name: Run tests
run: |
uv run pytest --device=cpu
31 changes: 22 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,42 @@ HBDesigner is an algorithm that designs highly-connected hydrogen bonding networ

## Installation

HBDesigner can be installed using `mamba` or `Pixi`. The following will create a `mamba` environment called `hbdesigner` with all necessary dependencies.
HBDesigner can be installed using `mamba`, `uv`, or `Pixi`.
```
# first, clone the repo
git clone https://github.com/Kuhlman-Lab/HBDesigner.git
cd HBDesigner/
```
To create a virtual environment with `mamba` for use on a GPU or CPU, respectively:
```
# for running on a GPU
mamba env create -f env.yaml
pip install .

# for running on a CPU
mamba env create -f env_cpu.yaml
pip install .
```
To create a virtual environment with `uv` for use on a GPU or CPU respectively:
```
# for running on a GPU
uv pip install -e ".[gpu]"

# for running on a CPU
uv pip install -e ".[cpu]"
```
We also provide an alternative install method using `Pixi`. The following will create a `Pixi` project in the `HBDesigner` root directory:
```
git clone https://github.com/Kuhlman-Lab/HBDesigner.git
cd HBDesigner/
pixi install

# to install with pixi to run on a CPU instead use
pixi install -e cpu
```

The `Pixi` installation is much faster than `mamba`, but requires slightly more awkward syntax when running `HBDesigner`. See `examples/monomer/run_with_pixi` for an example.

HBDesigner can be installed using the provided `install_hbdesigner.sh` script. This script requires the `mamba` package manager, but it can be readily adapted to use other package managers.
```
git clone https://github.com/Kuhlman-Lab/HBDesigner.git
cd HBDesigner/
sh install_hbdesigner.sh
```

## Using HBDesigner

A detailed guide for running HBDesigner on your protein(s) of interest can be found at `examples/README.md`, along with many example runscripts for common design scenarios.
Expand All @@ -46,4 +59,4 @@ The HBDesigner source code and model weights are provided under an MIT license (
If you find HBDesigner useful for your own work, please use the following citation:
```
TBD
```
```
2 changes: 1 addition & 1 deletion env.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: hbdesigner_debug
name: hbdesigner
channels:
- conda-forge
dependencies:
Expand Down
190 changes: 190 additions & 0 deletions env_cpu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
name: hbdesigner_cpu
channels:
- conda-forge
dependencies:
- _openmp_mutex=4.5=20_gnu
- bzip2=1.0.8=hda65f42_9
- ca-certificates=2026.2.25=hbd8a1cb_0
- icu=78.2=h33c6efd_0
- ld_impl_linux-64=2.45.1=default_hbd61a6d_101
- libexpat=2.7.4=hecca717_0
- libffi=3.5.2=h3435931_0
- libgcc=15.2.0=he0feb66_18
- libgcc-ng=15.2.0=h69a702a_18
- libgomp=15.2.0=he0feb66_18
- liblzma=5.8.2=hb03c661_0
- libnsl=2.0.1=hb9d3cd8_1
- libsqlite=3.51.2=hf4e2dac_0
- libstdcxx=15.2.0=h934c35e_18
- libuuid=2.41.3=h5347b49_0
- libxcrypt=4.4.36=hd590300_1
- libzlib=1.3.1=hb9d3cd8_2
- ncurses=6.5=h2d0b736_3
- openssl=3.6.1=h35e630c_1
- packaging=26.0=pyhcf101f3_0
- pip=26.0.1=pyh8b19718_0
- python=3.10.19=h3c07f61_3_cpython
- readline=8.3=h853b02a_0
- setuptools=82.0.0=pyh332efcf_0
- tk=8.6.13=noxft_h366c992_103
- wheel=0.46.3=pyhd8ed1ab_0
- zstd=1.5.7=hb78ec9c_6
- pip:
- --extra-index-url https://download.pytorch.org/whl/cpu
- --find-links https://data.pyg.org/whl/torch-2.8.0+cpu.html
- aiohappyeyeballs==2.6.1
- aiohttp==3.13.3
- aiosignal==1.4.0
- annotated-types==0.7.0
- antlr4-python3-runtime==4.9.3
- anyio==4.12.1
- argon2-cffi==25.1.0
- argon2-cffi-bindings==25.1.0
- arrow==1.4.0
- asttokens==3.0.1
- async-lru==2.2.0
- async-timeout==5.0.1
- attrs==25.4.0
- babel==2.18.0
- beautifulsoup4==4.14.3
- billiard==4.2.4
- biopython==1.86
- bleach==6.3.0
- blosc==1.11.4
- certifi==2026.2.25
- cffi==2.0.0
- charset-normalizer==3.4.4
- click==8.3.1
- cloudpickle==3.1.2
- comm==0.2.3
- cryptography==46.0.5
- dask==2026.1.2
- dask-jobqueue==0.9.0
- debugpy==1.8.20
- decorator==5.2.1
- defusedxml==0.7.1
- distributed==2026.1.2
- exceptiongroup==1.3.1
- executing==2.2.1
- fastjsonschema==2.21.2
- filelock==3.20.0
- fqdn==1.5.1
- frozenlist==1.8.0
- fsspec==2025.12.0
- gitdb==4.0.12
- gitpython==3.1.46
- h11==0.16.0
- httpcore==1.0.9
- httpx==0.28.1
- idna==3.11
- importlib-metadata==8.7.1
- ipykernel==7.2.0
- ipython==8.38.0
- ipywidgets==8.1.8
- isoduration==20.11.0
- jedi==0.19.2
- jinja2==3.1.6
- json5==0.13.0
- jsonpointer==3.0.0
- jsonschema==4.26.0
- jsonschema-specifications==2025.9.1
- jupyter==1.1.1
- jupyter-client==8.8.0
- jupyter-console==6.6.3
- jupyter-core==5.9.1
- jupyter-events==0.12.0
- jupyter-lsp==2.3.0
- jupyter-server==2.17.0
- jupyter-server-terminals==0.5.4
- jupyterlab==4.5.5
- jupyterlab-pygments==0.3.0
- jupyterlab-server==2.28.0
- jupyterlab-widgets==3.0.16
- lark==1.3.1
- locket==1.0.0
- markupsafe==3.0.2
- matplotlib-inline==0.2.1
- mistune==3.2.0
- mpmath==1.3.0
- msgpack==1.1.2
- multidict==6.7.1
- nbclient==0.10.4
- nbconvert==7.17.0
- nbformat==5.10.4
- nest-asyncio==1.6.0
- networkx==3.4.2
- notebook==7.5.4
- notebook-shim==0.2.4
- numpy==1.26.4
- omegaconf==2.3.0
- overrides==7.7.0
- pandas==2.3.3
- pandocfilters==1.5.1
- parso==0.8.6
- partd==1.4.2
- pebble==5.2.0
- pexpect==4.9.0
- platformdirs==4.9.2
- prometheus-client==0.24.1
- prompt-toolkit==3.0.52
- propcache==0.4.1
- protobuf==6.33.5
- psutil==7.2.2
- ptyprocess==0.7.0
- pure-eval==0.2.3
- pycparser==3.0
- pydantic==2.12.5
- pydantic-core==2.41.5
- pygments==2.19.2
- pyparsing==3.3.2
- --find-links https://west.rosettacommons.org/pyrosetta/quarterly/release.cxx11thread.serialization
- pyrosetta==2026.3
- python-dateutil==2.9.0.post0
- python-json-logger==4.0.0
- python-xz==0.6.0
- pytz==2025.2
- pyyaml==6.0.3
- pyzmq==27.1.0
- referencing==0.37.0
- requests==2.32.5
- rfc3339-validator==0.1.4
- rfc3986-validator==0.1.1
- rfc3987-syntax==1.1.0
- rpds-py==0.30.0
- scipy==1.15.3
- send2trash==2.1.0
- sentry-sdk==2.53.0
- six==1.17.0
- smmap==5.0.2
- sortedcontainers==2.4.0
- soupsieve==2.8.3
- stack-data==0.6.3
- sympy==1.14.0
- tblib==3.2.2
- terminado==0.18.1
- tinycss2==1.4.0
- tomli==2.4.0
- toolz==1.1.0
- torch==2.8.0+cpu
- torch-cluster==1.6.3+pt28cpu
- torch-geometric==2.7.0
- torch-scatter==2.1.2+pt28cpu
- tornado==6.5.4
- tqdm==4.67.3
- traitlets==5.14.3
- triton==3.4.0
- typing-extensions==4.15.0
- typing-inspection==0.4.2
- tzdata==2025.3
- uri-template==1.3.0
- urllib3==2.6.3
- wandb==0.25.0
- wcwidth==0.6.0
- webcolors==25.10.0
- webencodings==0.5.1
- websocket-client==1.9.0
- widgetsnbextension==4.0.15
- xxhash==3.6.0
- yarl==1.22.0
- zict==3.0.0
- zipp==3.23.0
19 changes: 13 additions & 6 deletions hbdesigner/inference/inference_hbdesigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
)
from hbdesigner.scripts.train_hbdesigner import HBDesignerDataset
from hbdesigner.scripts.train_hbpacker import HBPackerDataset
from hbdesigner.utils import seed_everything
from hbdesigner.utils import seed_everything, get_autocast_context
from pyrosetta.rosetta.basic.options import set_real_option


Expand Down Expand Up @@ -244,9 +244,15 @@ def validate_inputs(self):
assert n_anchor_res > 0, "You must provide at least one anchor residue if --anchor_res is specified."

# Retrieve model weights and configurations
self.opts.pack_cfg = os.path.join(Path(__file__).parents[2], "model_weights/pack.yaml")
#self.opts.pack_cfg = os.path.join(Path(__file__).parents[2], "model_weights/pack.yaml")
pack_cfg_name = "pack_cpu.yaml" if self.opts.cpu else "pack.yaml"
design_cfg_name = (
f"{self.opts.design_model}_cpu.yaml" if self.opts.cpu else f"{self.opts.design_model}.yaml"
)
self.opts.pack_cfg = os.path.join(Path(__file__).parents[2], f"model_weights/{pack_cfg_name}")
self.opts.pack_ckpt = os.path.join(Path(__file__).parents[2], "model_weights/pack.pt")
self.opts.design_cfg = os.path.join(Path(__file__).parents[2], f"model_weights/design_020.yaml")
#self.opts.design_cfg = os.path.join(Path(__file__).parents[2], f"model_weights/design_020.yaml")
self.opts.design_cfg = os.path.join(Path(__file__).parents[2], f"model_weights/{design_cfg_name}")
self.opts.design_ckpt = os.path.join(Path(__file__).parents[2], f"model_weights/{self.opts.design_model}.pt")
# These packing options are fixed for inference use
self.opts.pack_crop = 10.0
Expand Down Expand Up @@ -673,9 +679,10 @@ def pack_with_hbpacker(
# Iterate over dataloader with multiproc enabled
while True:
batch = next(dl)
with torch.autocast(device_type="cuda", dtype=torch.float16):
dev = next(model.parameters()).device
with get_autocast_context(dev):
packs = model.run_pack_recyc(
batch.to("cuda" if torch.cuda.is_available() else "cpu"),
batch.to(dev),
n_recycles=model.cfg.model.hbpacker.num_recycles,
)
packs = packs.to("cpu")
Expand Down Expand Up @@ -964,7 +971,7 @@ def sample_from_hbdesigner(
f"Current temps (res/seq): {res_sample_temp_c:.3f}/{seq_sample_temp_c:.3f}"
)
# Mixed precision inference is ~50% faster w/o dropping any performance
with torch.autocast(device_type="cuda", dtype=torch.float16):
with get_autocast_context(dev):
results = model.sample_new(
batch.clone(),
res_sample_temp=res_sample_temp_c,
Expand Down
5 changes: 5 additions & 0 deletions hbdesigner/inference/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def get_hbdes_parser() -> FileArgumentParser:
choices=["design_002", "design_020"],
help="Design model to use. Default is 'design_020' (moderate noise), but 'design_002' (low noise) is also available.",
)
parser.add_argument(
"--cpu",
action="store_true",
help="Run inference on CPU by loading CPU-specific YAML configs.",
)
parser.add_argument(
"--out_dir",
type=str,
Expand Down
5 changes: 4 additions & 1 deletion hbdesigner/scripts/train_hbdesigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,10 @@ def test_loop(

def load_model_state(self, ckpt_path: str) -> None:
# Load weights from saved checkpoint.
map_location = {"cuda:0": f"cuda:{self.rank}"}
if self.device.type == "cuda" and torch.cuda.is_available():
map_location = {"cuda:0": f"cuda:{self.rank}"}
else:
map_location = "cpu"
state = torch.load(ckpt_path, map_location=map_location)

self.model.load_state_dict(state["model_state_dict"])
Expand Down
Loading
Loading