|
5 | 5 | for further help. |
6 | 6 | """ |
7 | 7 | import os |
| 8 | +from dataclasses import dataclass, field, make_dataclass |
8 | 9 | import functools |
9 | 10 | import glob |
10 | 11 | import json |
11 | 12 | import logging |
| 13 | +from typing import Optional |
12 | 14 |
|
13 | 15 | import numpy as np |
14 | 16 | import torch |
|
23 | 25 | from profold2.model import profiler, snapshot, FeatureBuilder, ReturnValues |
24 | 26 | from profold2.utils import exists, timing |
25 | 27 |
|
| 28 | +from profold2.command import worker |
26 | 29 | from profold2.command.worker import main, autocast_ctx, WorkerModel, WorkerXPU |
27 | 30 |
|
28 | 31 |
|
@@ -128,7 +131,33 @@ def _location_split(model_location): |
128 | 131 | yield model_name, (features, model) |
129 | 132 |
|
130 | 133 |
|
131 | | -def predict(rank, args): # pylint: disable=redefined-outer-name |
| 134 | +@dataclass |
| 135 | +class Args(worker.Args): |
| 136 | + models: list[str] = field(default_factory=list) # models to be loaded |
| 137 | + # using[model_name=model_location] |
| 138 | + # format |
| 139 | + model_recycles: int = 0 # number of recycles |
| 140 | + model_shard_size: Optional[int] = 0 # shard size in the evoformer model |
| 141 | + map_location: str = 'cpu' # remapped to an alternative set of devices |
| 142 | + |
| 143 | + no_relaxer: bool = False # do NOT run relaxer |
| 144 | + no_gpu_relax: bool = False # force to run relax on cpu |
| 145 | + no_pth: bool = False # do NOT dump prediction headers |
| 146 | + |
| 147 | + data_dir: Optional[str] = None # dataset dir |
| 148 | + data_idx: Optional[str] = None # dataset idx |
| 149 | + add_pseudo_linker: bool = False # enable loading complex data |
| 150 | + |
| 151 | + fasta_files: list[str] = field(default_factory=list) # fasta files |
| 152 | + fasta_file_list: Optional[str] = None # file listing fasta files by line |
| 153 | + fasta_fmt: str = 'single' # single or a3m |
| 154 | + |
| 155 | + num_workers: int = 1 # number of workers |
| 156 | + |
| 157 | + max_msa_size: int = 1024 |
| 158 | + |
| 159 | + |
| 160 | +def run(rank, args): # pylint: disable=redefined-outer-name |
132 | 161 | model_runners = dict(_load_models(rank, args)) |
133 | 162 | logging.info('Have %d models: %s', len(model_runners), list(model_runners.keys())) |
134 | 163 |
|
@@ -360,35 +389,31 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name |
360 | 389 |
|
361 | 390 | if __name__ == '__main__': |
362 | 391 | import argparse |
| 392 | + import hydra |
363 | 393 |
|
364 | | - parser = argparse.ArgumentParser() |
365 | | - |
366 | | - # init distributed env |
367 | | - parser.add_argument('--nnodes', type=int, default=None, help='number of nodes.') |
368 | | - parser.add_argument('--node_rank', type=int, default=0, help='rank of the node.') |
369 | | - parser.add_argument( |
370 | | - '--local_rank', type=int, default=None, help='local rank of xpu, default=None' |
371 | | - ) |
372 | | - parser.add_argument( |
373 | | - '--init_method', |
374 | | - type=str, |
375 | | - default='file:///tmp/profold2.dist', |
376 | | - help='method to initialize the process group, ' |
377 | | - 'default=\'file:///tmp/profold2.dist\'' |
| 394 | + parser = argparse.ArgumentParser( |
| 395 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter |
378 | 396 | ) |
379 | 397 |
|
380 | | - # output dir |
| 398 | + parser.add_argument('-c', '--config', type=str, default=None, help='config file.') |
381 | 399 | parser.add_argument( |
382 | | - '-o', |
383 | | - '--prefix', |
384 | | - type=str, |
385 | | - default='.', |
386 | | - help='prefix of out directory, default=\'.\'' |
| 400 | + 'overrides', |
| 401 | + nargs='*', |
| 402 | + metavar='KEY=VAL', |
| 403 | + help='override configs, see: https://hydra.cc' |
387 | 404 | ) |
388 | | - add_arguments(parser) |
389 | | - # verbose |
390 | | - parser.add_argument('-v', '--verbose', action='store_true', help='verbose') |
391 | 405 |
|
392 | 406 | args = parser.parse_args() |
393 | | - |
394 | | - main(args, predict) |
| 407 | + config_dir, config_name = os.path.split( |
| 408 | + os.path.abspath(args.config) |
| 409 | + ) if exists(args.config) else (os.getcwd(), None) |
| 410 | + |
| 411 | + with hydra.initialize_config_dir( |
| 412 | + version_base=None, config_dir=config_dir, job_name=__file__ |
| 413 | + ): |
| 414 | + worker.main( |
| 415 | + make_dataclass('t', [], namespace={ |
| 416 | + 'Args': Args, |
| 417 | + 'run': run |
| 418 | + }), hydra.compose(config_name, args.overrides) |
| 419 | + ) |
0 commit comments