Skip to content

Commit 64a4c2b

Browse files
committed
refactor: configure with hydra
1 parent ec7685f commit 64a4c2b

7 files changed

Lines changed: 268 additions & 499 deletions

File tree

examples/predict.job

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -56,12 +56,12 @@ if [ x"${platform}" = x"slurm" ]; then
5656
node_opts=""
5757
else
5858
master_addr=${master_addr:-"127.0.0.1"}
59-
node_opts="--nnodes=${nnodes:-1} --node_rank=${node_rank:-0}"
59+
node_opts="+nnodes=${nnodes:-1} +node_rank=${node_rank:-0}"
6060
fi
6161
master_port=${master_port:-23456}
6262
echo "MasterAddr=${master_addr}:${master_port}"
6363
echo "==================================="
64-
node_opts="${node_opts} --init_method=tcp://${master_addr}:${master_port}"
64+
node_opts="${node_opts} +init_method=tcp://${master_addr}:${master_port}"
6565

6666
## init virtual environment if needed
6767
conda_home=${conda_home:-"${HOME}/.local/anaconda3"}
@@ -81,13 +81,13 @@ runner="python"
8181
if [ x"${platform}" = x"slurm" ]; then
8282
runner="srun ${runner}"
8383
fi
84-
${runner} ${PWD}/main.py ${node_opts} predict \
85-
--prefix=${CWD}/${exp}.pred${model_suffix} \
84+
${runner} ${PWD}/main.py predict \
85+
${node_opts} \
86+
+prefix=${CWD}/${exp}.pred${model_suffix} \
8687
\
87-
--models ${CWD}/${exp}.folding/model.pth${model_suffix} \
88-
--map_location=cpu \
89-
--model_recycles=2 \
90-
--model_shard_size=256 \
88+
+models=[${CWD}/${exp}.folding/model.pth${model_suffix}] \
89+
+model_recycles=2 \
90+
+model_shard_size=256 \
9191
\
92-
--fasta_fmt=single \
92+
+fasta_fmt=single \
9393
$*

install_env.sh

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,7 @@ conda install -y -c nvidia/label/cuda-${cuda_version} \
6565
conda install -y -c conda-forge \
6666
biopython \
6767
einops \
68+
hydra-core \
6869
tensorboard \
6970
tqdm \
7071
&& cleanup

profold2/command/evaluator.py

Lines changed: 35 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
for further help.
66
"""
77
import os
8+
from dataclasses import dataclass, make_dataclass
89
import logging
910
import pickle
1011

@@ -17,18 +18,23 @@
1718
from profold2.model import functional, profiler, snapshot, FeatureBuilder, ReturnValues
1819
from profold2.utils import exists, timing
1920

20-
from profold2.command.worker import main, autocast_ctx, WorkerModel, WorkerXPU
21+
from profold2.command import worker
2122

2223

23-
def evaluate(rank, args): # pylint: disable=redefined-outer-name
24-
worker = WorkerModel(rank, args)
25-
feats, model = worker.load(args.model)
26-
features = FeatureBuilder(feats).to(worker.device())
24+
@dataclass
25+
class Args(worker.Args):  # evaluator options; declares no fields beyond those inherited from worker.Args
26+
pass  # NOTE(review): run() reads args.model, args.eval_data, args.amp_enabled and args.nnodes - confirm worker.Args (not shown here) provides them
27+
28+
29+
def run(rank, args): # pylint: disable=redefined-outer-name
30+
runner = worker.WorkerModel(rank, args)
31+
feats, model = runner.load(args.model)
32+
features = FeatureBuilder(feats).to(runner.device())
2733
logging.info('feats: %s', feats)
2834

2935
kwargs = {}
30-
if rank.is_available() and WorkerXPU.world_size(args.nnodes) > 1:
31-
kwargs['num_replicas'] = WorkerXPU.world_size(args.nnodes)
36+
if rank.is_available() and worker.world_size(args.nnodes) > 1:
37+
kwargs['num_replicas'] = worker.world_size(args.nnodes)
3238
kwargs['rank'] = rank.rank
3339
test_loader = dataset.load(
3440
data_dir=args.eval_data,
@@ -68,7 +74,7 @@ def data_eval(idx, batch):
6874
# predict - out is (batch, L * 3, 3)
6975
with timing(f'Running model on {fasta_name} {fasta_len}', logging.debug):
7076
with torch.no_grad():
71-
with autocast_ctx(args.amp_enabled):
77+
with worker.autocast_ctx(args.amp_enabled):
7278
r = ReturnValues(
7379
**model(
7480
batch=batch, # pylint: disable=not-callable
@@ -320,34 +326,31 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name
320326

321327
if __name__ == '__main__':
322328
import argparse
329+
import hydra
323330

324-
parser = argparse.ArgumentParser()
325-
326-
# init distributed env
327-
parser.add_argument('--nnodes', type=int, default=None, help='number of nodes.')
328-
parser.add_argument('--node_rank', type=int, default=0, help='rank of the node.')
329-
parser.add_argument(
330-
'--local_rank', type=int, default=None, help='local rank of xpu, default=None'
331-
)
332-
parser.add_argument(
333-
'--init_method',
334-
type=str,
335-
default='file:///tmp/profold2.dist',
336-
help='method to initialize the process group, '
337-
'default=\'file:///tmp/profold2.dist\''
331+
parser = argparse.ArgumentParser(
332+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
338333
)
339334

340-
# output dir
335+
parser.add_argument('-c', '--config', type=str, default=None, help='config file.')
341336
parser.add_argument(
342-
'-o',
343-
'--prefix',
344-
type=str,
345-
default='.',
346-
help='prefix of out directory, default=\'.\''
337+
'overrides',
338+
nargs='*',
339+
metavar='KEY=VAL',
340+
help='override configs, see: https://hydra.cc'
347341
)
348-
add_arguments(parser)
349-
parser.add_argument('-v', '--verbose', action='store_true', help='verbose')
350342

351343
args = parser.parse_args()
352-
353-
main(args, evaluate)
344+
config_dir, config_name = os.path.split(
345+
os.path.abspath(args.config)
346+
) if exists(args.config) else (os.getcwd(), None)
347+
348+
with hydra.initialize_config_dir(
349+
version_base=None, config_dir=config_dir, job_name=__file__
350+
):
351+
worker.main(
352+
make_dataclass('t', [], namespace={
353+
'Args': Args,
354+
'run': run
355+
}), hydra.compose(config_name, args.overrides)
356+
)

profold2/command/main.py

Lines changed: 25 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -4,75 +4,56 @@
44
```
55
for further help.
66
"""
7+
import os
78
import argparse
89

10+
import hydra
11+
912
from profold2.command import (evaluator, predictor, trainer, worker)
10-
from profold2.utils import env
13+
from profold2.utils import env, exists
1114

12-
_COMMANDS = [
13-
('train', trainer.train, trainer.add_arguments),
14-
('evaluate', evaluator.evaluate, evaluator.add_arguments),
15-
('predict', predictor.predict, predictor.add_arguments),
16-
]
15+
_COMMANDS = [('train', trainer), ('evaluate', evaluator), ('predict', predictor)]
1716

1817

1918
def create_args():
2019
formatter_class = argparse.ArgumentDefaultsHelpFormatter
2120
parser = argparse.ArgumentParser(formatter_class=formatter_class)
2221

23-
# distributed args
24-
parser.add_argument(
25-
'--nnodes',
26-
type=int,
27-
default=env('SLURM_NNODES', defval=None, dtype=int),
28-
help='number of nodes.'
29-
)
30-
parser.add_argument(
31-
'--node_rank',
32-
type=int,
33-
default=env('SLURM_NODEID', defval=0, dtype=int),
34-
help='rank of the node.'
35-
)
36-
parser.add_argument(
37-
'--local_rank',
38-
type=int,
39-
default=int(env('LOCAL_RANK', defval=0, dtype=int)),
40-
help='local rank of xpu.'
41-
)
42-
parser.add_argument(
43-
'--init_method',
44-
type=str,
45-
default=None,
46-
help='method to initialize the process group.'
47-
)
48-
4922
# command args
5023
subparsers = parser.add_subparsers(dest='command', required=True)
51-
for cmd, _, add_arguments in _COMMANDS:
24+
for cmd, _ in _COMMANDS:
5225
cmd_parser = subparsers.add_parser(cmd, formatter_class=formatter_class)
53-
54-
# output dir
5526
cmd_parser.add_argument(
56-
'-o', '--prefix', type=str, default='.', help='prefix of out directory.'
27+
'-c', '--config', type=str, default=None, help='config file.'
28+
)
29+
cmd_parser.add_argument(
30+
'overrides',
31+
nargs='*',
32+
metavar='KEY=VAL',
33+
help='override configs, see: https://hydra.cc'
5734
)
58-
add_arguments(cmd_parser)
59-
# verbose
60-
cmd_parser.add_argument('-v', '--verbose', action='store_true', help='verbose')
6135

6236
return parser.parse_args()
6337

6438

65-
def create_fn(args): # pylint: disable=redefined-outer-name
66-
for cmd, fn, _ in _COMMANDS:
39+
def create_task(args): # pylint: disable=redefined-outer-name
40+
for cmd, task in _COMMANDS:
6741
if cmd == args.command:
68-
return fn
42+
return task
6943
return None
7044

7145

7246
def main():
7347
args = create_args()
74-
work_fn = create_fn(args)
75-
worker.main(args, work_fn)
48+
config_dir, config_name = os.path.split(
49+
os.path.abspath(args.config)
50+
) if exists(args.config) else (os.getcwd(), None)
51+
52+
with hydra.initialize_config_dir(
53+
version_base=None, config_dir=config_dir, job_name=args.command
54+
):
55+
task = create_task(args)
56+
worker.main(task, hydra.compose(config_name, args.overrides))
7657

7758

7859
if __name__ == '__main__':

profold2/command/predictor.py

Lines changed: 51 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -5,10 +5,12 @@
55
for further help.
66
"""
77
import os
8+
from dataclasses import dataclass, field, make_dataclass
89
import functools
910
import glob
1011
import json
1112
import logging
13+
from typing import Optional
1214

1315
import numpy as np
1416
import torch
@@ -23,6 +25,7 @@
2325
from profold2.model import profiler, snapshot, FeatureBuilder, ReturnValues
2426
from profold2.utils import exists, timing
2527

28+
from profold2.command import worker
2629
from profold2.command.worker import main, autocast_ctx, WorkerModel, WorkerXPU
2730

2831

@@ -128,7 +131,33 @@ def _location_split(model_location):
128131
yield model_name, (features, model)
129132

130133

131-
def predict(rank, args): # pylint: disable=redefined-outer-name
134+
@dataclass
135+
class Args(worker.Args):  # options for the predict command, layered on top of worker.Args
136+
models: list[str] = field(default_factory=list) # models to be loaded,
137+
# each given as model_name=model_location
138+
# NOTE(review): examples/predict.job passes a bare location - confirm the name part is optional
139+
model_recycles: int = 0 # number of recycles
140+
model_shard_size: Optional[int] = 0 # shard size in the evoformer model
141+
map_location: str = 'cpu' # remapped to an alternative set of devices
142+
143+
no_relaxer: bool = False # do NOT run relaxer
144+
no_gpu_relax: bool = False # force to run relax on cpu
145+
no_pth: bool = False # do NOT dump prediction headers
146+
147+
data_dir: Optional[str] = None # dataset dir
148+
data_idx: Optional[str] = None # dataset idx
149+
add_pseudo_linker: bool = False # enable loading complex data
150+
151+
fasta_files: list[str] = field(default_factory=list) # fasta files
152+
fasta_file_list: Optional[str] = None # file listing fasta files by line
153+
fasta_fmt: str = 'single' # single or a3m
154+
155+
num_workers: int = 1 # number of workers
156+
157+
max_msa_size: int = 1024 # NOTE(review): undocumented; presumably an upper bound on MSA size - confirm
158+
159+
160+
def run(rank, args): # pylint: disable=redefined-outer-name
132161
model_runners = dict(_load_models(rank, args))
133162
logging.info('Have %d models: %s', len(model_runners), list(model_runners.keys()))
134163

@@ -360,35 +389,31 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name
360389

361390
if __name__ == '__main__':
362391
import argparse
392+
import hydra
363393

364-
parser = argparse.ArgumentParser()
365-
366-
# init distributed env
367-
parser.add_argument('--nnodes', type=int, default=None, help='number of nodes.')
368-
parser.add_argument('--node_rank', type=int, default=0, help='rank of the node.')
369-
parser.add_argument(
370-
'--local_rank', type=int, default=None, help='local rank of xpu, default=None'
371-
)
372-
parser.add_argument(
373-
'--init_method',
374-
type=str,
375-
default='file:///tmp/profold2.dist',
376-
help='method to initialize the process group, '
377-
'default=\'file:///tmp/profold2.dist\''
394+
parser = argparse.ArgumentParser(
395+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
378396
)
379397

380-
# output dir
398+
parser.add_argument('-c', '--config', type=str, default=None, help='config file.')
381399
parser.add_argument(
382-
'-o',
383-
'--prefix',
384-
type=str,
385-
default='.',
386-
help='prefix of out directory, default=\'.\''
400+
'overrides',
401+
nargs='*',
402+
metavar='KEY=VAL',
403+
help='override configs, see: https://hydra.cc'
387404
)
388-
add_arguments(parser)
389-
# verbose
390-
parser.add_argument('-v', '--verbose', action='store_true', help='verbose')
391405

392406
args = parser.parse_args()
393-
394-
main(args, predict)
407+
config_dir, config_name = os.path.split(
408+
os.path.abspath(args.config)
409+
) if exists(args.config) else (os.getcwd(), None)
410+
411+
with hydra.initialize_config_dir(
412+
version_base=None, config_dir=config_dir, job_name=__file__
413+
):
414+
worker.main(
415+
make_dataclass('t', [], namespace={
416+
'Args': Args,
417+
'run': run
418+
}), hydra.compose(config_name, args.overrides)
419+
)

0 commit comments

Comments
 (0)