Skip to content

Commit e886c26

Browse files
committed
:refactor: configure with hydra
1 parent f6fb34c commit e886c26

7 files changed

Lines changed: 241 additions & 509 deletions

File tree

examples/predict.job

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ if [ x"${platform}" = x"slurm" ]; then
5656
node_opts=""
5757
else
5858
master_addr=${master_addr:-"127.0.0.1"}
59-
node_opts="--nnodes=${nnodes:-1} --node_rank=${node_rank:-0}"
59+
node_opts="+nnodes=${nnodes:-1} +node_rank=${node_rank:-0}"
6060
fi
6161
master_port=${master_port:-23456}
6262
echo "MasterAddr=${master_addr}:${master_port}"
6363
echo "==================================="
64-
node_opts="${node_opts} --init_method=tcp://${master_addr}:${master_port}"
64+
node_opts="${node_opts} +init_method=tcp://${master_addr}:${master_port}"
6565

6666
## init virtual environment if needed
6767
conda_home=${conda_home:-"${HOME}/.local/anaconda3"}
@@ -81,13 +81,13 @@ runner="python"
8181
if [ x"${platform}" = x"slurm" ]; then
8282
runner="srun ${runner}"
8383
fi
84-
${runner} ${PWD}/main.py ${node_opts} predict \
85-
--prefix=${CWD}/${exp}.pred${model_suffix} \
84+
${runner} ${PWD}/main.py predict \
85+
${node_opts} \
86+
+prefix=${CWD}/${exp}.pred${model_suffix} \
8687
\
87-
--models ${CWD}/${exp}.folding/model.pth${model_suffix} \
88-
--map_location=cpu \
89-
--model_recycles=2 \
90-
--model_shard_size=256 \
88+
+models=[${CWD}/${exp}.folding/model.pth${model_suffix}] \
89+
+model_recycles=2 \
90+
+model_shard_size=256 \
9191
\
92-
--fasta_fmt=single \
92+
+fasta_fmt=single \
9393
$*

install_env.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ conda install -y -c nvidia/label/cuda-${cuda_version} \
6565
conda install -y -c conda-forge \
6666
biopython \
6767
einops \
68+
hydra-core \
6869
tensorboard \
6970
tqdm \
7071
&& cleanup

profold2/command/evaluator.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
for further help.
66
"""
77
import os
8+
from dataclasses import dataclass, make_dataclass
89
import logging
910
import pickle
1011

@@ -22,7 +23,12 @@
2223
from profold2.command import worker
2324

2425

25-
def evaluate(rank, args): # pylint: disable=redefined-outer-name
26+
@dataclass
27+
class Args(worker.Args):
28+
pass
29+
30+
31+
def run(rank, args): # pylint: disable=redefined-outer-name
2632
wm = worker.WorkerModel(rank, args)
2733
feats, model = wm.load(args.model)
2834
features = FeatureBuilder(feats).to(wm.device())
@@ -323,34 +329,31 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name
323329

324330
if __name__ == '__main__':
325331
import argparse
332+
import hydra
326333

327-
parser = argparse.ArgumentParser()
328-
329-
# init distributed env
330-
parser.add_argument('--nnodes', type=int, default=None, help='number of nodes.')
331-
parser.add_argument('--node_rank', type=int, default=0, help='rank of the node.')
332-
parser.add_argument(
333-
'--local_rank', type=int, default=None, help='local rank of xpu, default=None'
334-
)
335-
parser.add_argument(
336-
'--init_method',
337-
type=str,
338-
default='file:///tmp/profold2.dist',
339-
help='method to initialize the process group, '
340-
'default=\'file:///tmp/profold2.dist\''
334+
parser = argparse.ArgumentParser(
335+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
341336
)
342337

343-
# output dir
338+
parser.add_argument('-c', '--config', type=str, default=None, help='config file.')
344339
parser.add_argument(
345-
'-o',
346-
'--prefix',
347-
type=str,
348-
default='.',
349-
help='prefix of out directory, default=\'.\''
340+
'overrides',
341+
nargs='*',
342+
metavar='KEY=VAL',
343+
help='override configs, see: https://hydra.cc'
350344
)
351-
add_arguments(parser)
352-
parser.add_argument('-v', '--verbose', action='store_true', help='verbose')
353345

354346
args = parser.parse_args()
355-
356-
worker.main(args, evaluate)
347+
config_dir, config_name = os.path.split(
348+
os.path.abspath(args.config)
349+
) if exists(args.config) else (os.getcwd(), None)
350+
351+
with hydra.initialize_config_dir(
352+
version_base=None, config_dir=config_dir, job_name=__file__
353+
):
354+
worker.main(
355+
make_dataclass('t', [], namespace={
356+
'Args': Args,
357+
'run': run
358+
}), hydra.compose(config_name, args.overrides)
359+
)

profold2/command/main.py

Lines changed: 25 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,75 +4,56 @@
44
```
55
for further help.
66
"""
7+
import os
78
import argparse
89

10+
import hydra
11+
912
from profold2.command import (evaluator, predictor, trainer, worker)
10-
from profold2.utils import env
13+
from profold2.utils import env, exists
1114

12-
_COMMANDS = [
13-
('train', trainer.train, trainer.add_arguments),
14-
('evaluate', evaluator.evaluate, evaluator.add_arguments),
15-
('predict', predictor.predict, predictor.add_arguments),
16-
]
15+
_COMMANDS = [('train', trainer), ('evaluate', evaluator), ('predict', predictor)]
1716

1817

1918
def create_args():
2019
formatter_class = argparse.ArgumentDefaultsHelpFormatter
2120
parser = argparse.ArgumentParser(formatter_class=formatter_class)
2221

23-
# distributed args
24-
parser.add_argument(
25-
'--nnodes',
26-
type=int,
27-
default=env('SLURM_NNODES', defval=None, dtype=int),
28-
help='number of nodes.'
29-
)
30-
parser.add_argument(
31-
'--node_rank',
32-
type=int,
33-
default=env('SLURM_NODEID', defval=0, dtype=int),
34-
help='rank of the node.'
35-
)
36-
parser.add_argument(
37-
'--local_rank',
38-
type=int,
39-
default=int(env('LOCAL_RANK', defval=0, dtype=int)),
40-
help='local rank of xpu.'
41-
)
42-
parser.add_argument(
43-
'--init_method',
44-
type=str,
45-
default=None,
46-
help='method to initialize the process group.'
47-
)
48-
4922
# command args
5023
subparsers = parser.add_subparsers(dest='command', required=True)
51-
for cmd, _, add_arguments in _COMMANDS:
24+
for cmd, _ in _COMMANDS:
5225
cmd_parser = subparsers.add_parser(cmd, formatter_class=formatter_class)
53-
54-
# output dir
5526
cmd_parser.add_argument(
56-
'-o', '--prefix', type=str, default='.', help='prefix of out directory.'
27+
'-c', '--config', type=str, default=None, help='config file.'
28+
)
29+
cmd_parser.add_argument(
30+
'overrides',
31+
nargs='*',
32+
metavar='KEY=VAL',
33+
help='override configs, see: https://hydra.cc'
5734
)
58-
add_arguments(cmd_parser)
59-
# verbose
60-
cmd_parser.add_argument('-v', '--verbose', action='store_true', help='verbose')
6135

6236
return parser.parse_args()
6337

6438

65-
def create_fn(args): # pylint: disable=redefined-outer-name
66-
for cmd, fn, _ in _COMMANDS:
39+
def create_task(args): # pylint: disable=redefined-outer-name
40+
for cmd, task in _COMMANDS:
6741
if cmd == args.command:
68-
return fn
42+
return task
6943
return None
7044

7145

7246
def main():
7347
args = create_args()
74-
work_fn = create_fn(args)
75-
worker.main(args, work_fn)
48+
config_dir, config_name = os.path.split(
49+
os.path.abspath(args.config)
50+
) if exists(args.config) else (os.getcwd(), None)
51+
52+
with hydra.initialize_config_dir(
53+
version_base=None, config_dir=config_dir, job_name=args.command
54+
):
55+
task = create_task(args)
56+
worker.main(task, hydra.compose(config_name, args.overrides))
7657

7758

7859
if __name__ == '__main__':

profold2/command/predictor.py

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
for further help.
66
"""
77
import os
8+
from dataclasses import dataclass, field, make_dataclass
89
import functools
910
import glob
1011
import json
1112
import logging
13+
from typing import Optional
1214

1315
import numpy as np
1416
import torch
@@ -128,7 +130,33 @@ def _location_split(model_location):
128130
yield model_name, (features, model)
129131

130132

131-
def predict(rank, args): # pylint: disable=redefined-outer-name
133+
@dataclass
134+
class Args(worker.Args):
135+
models: list[str] = field(default_factory=list) # models to be loaded
136+
# using [model_name=model_location]
137+
# format
138+
model_recycles: int = 0 # number of recycles
139+
model_shard_size: Optional[int] = 0 # shard size in the evoformer model
140+
map_location: str = 'cpu' # remapped to an alternative set of devices
141+
142+
no_relaxer: bool = False # do NOT run relaxer
143+
no_gpu_relax: bool = False # force to run relax on cpu
144+
no_pth: bool = False # do NOT dump prediction headers
145+
146+
data_dir: Optional[str] = None # dataset dir
147+
data_idx: Optional[str] = None # dataset idx
148+
add_pseudo_linker: bool = False # enable loading complex data
149+
150+
fasta_files: list[str] = field(default_factory=list) # fasta files
151+
fasta_file_list: Optional[str] = None # file listing fasta files by line
152+
fasta_fmt: str = 'single' # single or a3m
153+
154+
num_workers: int = 1 # number of workers
155+
156+
max_msa_size: int = 1024
157+
158+
159+
def run(rank, args): # pylint: disable=redefined-outer-name
132160
model_runners = dict(_load_models(rank, args))
133161
logging.info('Have %d models: %s', len(model_runners), list(model_runners.keys()))
134162

@@ -360,35 +388,31 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name
360388

361389
if __name__ == '__main__':
362390
import argparse
391+
import hydra
363392

364-
parser = argparse.ArgumentParser()
365-
366-
# init distributed env
367-
parser.add_argument('--nnodes', type=int, default=None, help='number of nodes.')
368-
parser.add_argument('--node_rank', type=int, default=0, help='rank of the node.')
369-
parser.add_argument(
370-
'--local_rank', type=int, default=None, help='local rank of xpu, default=None'
371-
)
372-
parser.add_argument(
373-
'--init_method',
374-
type=str,
375-
default='file:///tmp/profold2.dist',
376-
help='method to initialize the process group, '
377-
'default=\'file:///tmp/profold2.dist\''
393+
parser = argparse.ArgumentParser(
394+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
378395
)
379396

380-
# output dir
397+
parser.add_argument('-c', '--config', type=str, default=None, help='config file.')
381398
parser.add_argument(
382-
'-o',
383-
'--prefix',
384-
type=str,
385-
default='.',
386-
help='prefix of out directory, default=\'.\''
399+
'overrides',
400+
nargs='*',
401+
metavar='KEY=VAL',
402+
help='override configs, see: https://hydra.cc'
387403
)
388-
add_arguments(parser)
389-
# verbose
390-
parser.add_argument('-v', '--verbose', action='store_true', help='verbose')
391404

392405
args = parser.parse_args()
393-
394-
worker.main(args, predict)
406+
config_dir, config_name = os.path.split(
407+
os.path.abspath(args.config)
408+
) if exists(args.config) else (os.getcwd(), None)
409+
410+
with hydra.initialize_config_dir(
411+
version_base=None, config_dir=config_dir, job_name=__file__
412+
):
413+
worker.main(
414+
make_dataclass('t', [], namespace={
415+
'Args': Args,
416+
'run': run
417+
}), hydra.compose(config_name, args.overrides)
418+
)

0 commit comments

Comments
 (0)