-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoptimize.py
More file actions
134 lines (103 loc) · 5.59 KB
/
optimize.py
File metadata and controls
134 lines (103 loc) · 5.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import optuna
from optuna.trial import TrialState
from optuna.visualization import plot_optimization_history, plot_param_importances
import argparse
import sys
# Importing tool modules
from tools import goldrush, test
def optuna_get_sampler(sampler_name, seed=None):
    """Return the Optuna sampler instance selected by *sampler_name*.

    Args:
        sampler_name: One of "random", "tpe", or "cmaes".
        seed: Optional random seed forwarded to the sampler for
            reproducible runs (all three samplers accept a ``seed``
            keyword). Defaults to None, matching the previous behavior.

    Returns:
        A configured sampler from ``optuna.samplers``.

    Raises:
        ValueError: If ``sampler_name`` is not a recognized sampler.
    """
    if sampler_name == "random":
        return optuna.samplers.RandomSampler(seed=seed)
    elif sampler_name == "tpe":
        return optuna.samplers.TPESampler(seed=seed)
    elif sampler_name == "cmaes":
        return optuna.samplers.CmaEsSampler(seed=seed)
    else:
        raise ValueError(f"Unknown sampler: {sampler_name}")
def optuna_get_pruner(pruner_name):
    """Return the Optuna pruner instance selected by *pruner_name*.

    Recognized names: "median", "nop", "halving", "hyperband".
    Raises ValueError for anything else.
    """
    # Lazy factories so nothing is constructed until the name is validated.
    factories = {
        "median": lambda: optuna.pruners.MedianPruner(),
        "nop": lambda: optuna.pruners.NopPruner(),
        "halving": lambda: optuna.pruners.SuccessiveHalvingPruner(),
        "hyperband": lambda: optuna.pruners.HyperbandPruner(),
    }
    if pruner_name not in factories:
        raise ValueError(f"Unknown pruner: {pruner_name}")
    return factories[pruner_name]()
def parse_arguments():
    """Parse command-line arguments for the hyperparameter-optimization CLI.

    A required sub-command selects the tool to optimize ("goldrush" or
    "test"); every sub-command shares the same Optuna-related options.
    When invoked with no arguments at all, prints usage to stderr and
    exits with status 1.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Optimize hyperparameters for bioinformatics tools using Optuna.")
    subparsers = parser.add_subparsers(dest='mode', required=True, help='Choose a tool to optimize hyperparameters for.')

    # Shared arguments: attached to every sub-command so all tools expose
    # the same optimization knobs.
    def add_shared_arguments(subparser):
        subparser.add_argument("--sampler", type=str, choices=['random', 'tpe', 'cmaes'], default='tpe', help='Sampler to be used for hyperparameter optimization. Default is "tpe".')
        subparser.add_argument("--pruner", type=str, choices=['median', 'nop', 'halving', 'hyperband'], default='nop', help='Pruner to be used for hyperparameter optimization. Default is "nop" meaning No Pruner.')
        subparser.add_argument("-n", "--n_trials", type=int, default=100, help="Number of trials for optimization. Default is 100.")
        subparser.add_argument('--seed', type=int, default=192, help='Random seed for reproducibility.')
        subparser.add_argument("-d", "--direction", type=str, choices=['minimize', 'maximize'], default='minimize', help="Direction of optimization. Default is 'minimize'.")
        subparser.add_argument("--storage", type=str, default=None, help="Database URL for Optuna. (default='None'). If you're running experiments that you don't wish to persist, consider using Optuna's in-memory storage: 'sqlite:///:memory:', otherwise select a db name: 'sqlite:///goldrush_optuna.db' for example.")
        subparser.add_argument("-s", "--study_name", type=str, default="biooptuna_study", help="Name of the Optuna study. Default is 'biooptuna_study'.")
        subparser.add_argument("--plot", action='store_true', help="Plot the optimization history and parameter importances.")

    # Goldrush subparser
    goldrush_parser = subparsers.add_parser('goldrush', help='Optimize hyperparameters for Goldrush.')
    add_shared_arguments(goldrush_parser)
    # Test subparser
    test_parser = subparsers.add_parser('test', help='Optimize hyperparameters for test function.')
    add_shared_arguments(test_parser)
    # Add more subparsers as needed (e.g. for other bioinformatics tools developed within BTL lab)

    # With no CLI arguments at all, show full usage instead of the terser
    # argparse "required sub-command" error.
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)
    return parser.parse_args()
def main(args):
    """Run an Optuna study for the tool selected by ``args.mode``.

    Builds the sampler and pruner from the CLI options, optimizes the
    selected tool module's objective for ``args.n_trials`` trials,
    optionally plots the results, and prints a study summary.

    Args:
        args: argparse.Namespace produced by ``parse_arguments``.

    Raises:
        ValueError: If ``args.mode`` is not a known tool.
    """
    # Print out the mode, arguments and their values
    print(f"Optimizing hyperparameters for {args.mode} with the following arguments:")
    for arg in vars(args):
        print(f"{arg}: {getattr(args, arg)}")

    # Depending on the mode, select the appropriate tool module
    if args.mode == 'goldrush':
        tool_module = goldrush
    elif args.mode == 'test':
        tool_module = test
    # ... handle other sub-commands ...
    else:
        raise ValueError("Invalid mode. Please specify a valid mode.")

    # NOTE(review): args.seed is parsed but never forwarded to the sampler,
    # so runs are not reproducible yet — confirm and wire it through.
    sampler = optuna_get_sampler(args.sampler)
    pruner = optuna_get_pruner(args.pruner)
    storage_name = args.storage
    study_name = args.study_name

    # Define the study and objective. Both wrappers ask the tool module for
    # its search space, then evaluate it; the pruning variant additionally
    # reports intermediate values so the pruner can stop bad trials early.
    def objective(trial):
        params = tool_module.define_search_space(trial)
        return tool_module.objective(params)

    def objective_with_pruning(trial):
        params = tool_module.define_search_space(trial)
        return tool_module.objective_with_pruning(params)

    # Persist to the given storage URL (resuming an existing study of the
    # same name); with no storage the study lives in memory only.
    if storage_name is None:
        study = optuna.create_study(direction=args.direction, sampler=sampler, pruner=pruner, study_name=study_name)
    else:
        study = optuna.create_study(direction=args.direction, sampler=sampler, pruner=pruner, storage=storage_name, study_name=study_name, load_if_exists=True)

    # The pruning-aware objective is only meaningful when a real pruner is
    # active; with the no-op pruner use the plain objective.
    if args.pruner == "nop":
        study.optimize(objective, n_trials=args.n_trials)
    else:
        study.optimize(objective_with_pruning, n_trials=args.n_trials)

    if args.plot:
        plot_param_importances(study).show()
        plot_optimization_history(study).show()

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
    print("Study statistics: ")
    print(" Number of finished trials: ", len(study.trials))
    print(" Number of pruned trials: ", len(pruned_trials))
    print(" Number of complete trials: ", len(complete_trials))

    # study.best_trial raises ValueError when no trial completed (e.g.
    # everything was pruned or failed) — report that instead of crashing.
    if not complete_trials:
        print("No completed trials; best trial is unavailable.")
        return
    print("Best trial:")
    trial = study.best_trial
    print(" Value: ", trial.value)
    print(" Params: ")
    for key, value in trial.params.items():
        print(" {}: {}".format(key, value))
# Script entry point: parse the CLI and hand the options to main().
if __name__ == "__main__":
    main(parse_arguments())