Skip to content

Commit 6b392be

Browse files
authored
feat: Add support for conditional search spaces in the tuner (#172)
# Summary Allows a custom function to be specified as the definition for the `Tuner` search space. This allows for more complex search space definitions, e.g. branching, looping. See https://docs.ray.io/en/latest/tune/examples/optuna_example.html#conditional-search-spaces for more information on how Ray provides this (via Optuna). # Changes * New `space` attribute on the `OptunaSpec` algorithm definition. * Implementation and tests.
1 parent 1784269 commit 6b392be

8 files changed

Lines changed: 291 additions & 49 deletions

File tree

docs/examples/tutorials/tuning-a-process.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,31 @@ Plugboard's YAML config supports an optional `tune` section, allowing you to def
7373
3. Parameters need to reference a type, so that Plugboard knows the type of parameter to build.
7474

7575
Now run `plugboard process tune model-with-tuner.yaml` to execute the optimisation job from the CLI.
76+
77+
## Advanced usage: complex search spaces
78+
79+
Occasionally you may need to define more complex search spaces, which go beyond what can be defined with a simple parameter configuration. For example:
80+
81+
* Conditional parameters, e.g. where parameter `a` must be greater than parameter `b`; or
82+
* Looping, e.g. building up a list of tunable parameters that is of variable length.
83+
84+
These conditional search space functions are supported by Ray Tune and can be defined as described in the [Ray documentation](https://docs.ray.io/en/latest/tune/examples/optuna_example.html#conditional-search-spaces). To use such a function you will need to:
85+
86+
1. Set up the [`Tuner`][plugboard.tune.Tuner], defining your parameters as usual;
87+
2. Write a custom function to define the search space, where each tunable parameter has a name of the form `"{component_name}.{field_or_arg_name}"` (e.g. `"trajectory.angle"`); then
88+
3. Supply your custom function to the `OptunaSpec` algorithm configuration.
89+
90+
For example, the following search space makes the velocity depend on the angle:
91+
```python
92+
--8<-- "examples/tutorials/006_optimisation/hello_tuner.py:custom_search_space"
93+
```
94+
95+
Then use this configuration to point the tuner to the `custom_space` function.
96+
```yaml
97+
--8<-- "examples/tutorials/006_optimisation/model-with-tuner-custom.yaml"
98+
```
99+
100+
1. We can reference a [`Process`][plugboard.process.Process] from another YAML file here to avoid repetition.
101+
2. Add the algorithm configuration here.
102+
103+
Then run using `plugboard process tune model-with-tuner-custom.yaml`.

examples/tutorials/006_optimisation/hello_tuner.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# fmt: off
44
import typing as _t
5-
5+
from optuna import Trial
66
from plugboard.component import Component, IOController as IO
77
from plugboard.process import ProcessBuilder
88
from plugboard.schemas import ComponentArgsDict, ProcessSpec, ProcessArgsSpec, ObjectiveSpec
@@ -64,6 +64,15 @@ async def step(self) -> None:
6464
# --8<-- [end:components]
6565

6666

67+
# --8<-- [start:custom_search_space]
68+
def custom_space(trial: Trial) -> dict[str, _t.Any] | None:
69+
"""Defines a custom search space for Optuna."""
70+
angle = trial.suggest_int("trajectory.angle", 0, 90)
71+
# Make velocity depend on angle
72+
trial.suggest_int("trajectory.velocity", angle, 100)
73+
# --8<-- [end:custom_search_space]
74+
75+
6776
if __name__ == "__main__":
6877
# --8<-- [start:define_process]
6978
process_spec = ProcessSpec(
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
plugboard:
2+
process: model-with-tuner.yaml # (1)!
3+
tune:
4+
args:
5+
objective:
6+
object_name: max-height
7+
field_type: field
8+
field_name: max_y
9+
parameters:
10+
- type: ray.tune.uniform
11+
object_type: component
12+
object_name: trajectory
13+
field_type: arg
14+
field_name: angle
15+
lower: 0
16+
upper: 90
17+
- type: ray.tune.uniform
18+
object_type: component
19+
object_name: trajectory
20+
field_type: arg
21+
field_name: velocity
22+
lower: 0
23+
upper: 100
24+
num_samples: 40
25+
mode: max
26+
max_concurrent: 4
27+
algorithm: # (2)!
28+
type: ray.tune.search.optuna.OptunaSearch
29+
space: hello_tuner.custom_space
30+

plugboard/cli/process/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ def tune(
129129
) -> None:
130130
"""Optimise a Plugboard process by adjusting its tunable parameters."""
131131
config_spec = _read_yaml(config)
132-
tuner = _build_tuner(config_spec)
132+
with add_sys_path(config.parent):
133+
tuner = _build_tuner(config_spec)
133134

134135
with Progress(
135136
SpinnerColumn("arrow3"),

plugboard/schemas/tune.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@ class OptunaSpec(PlugboardBaseModel):
1717
1818
Attributes:
1919
type: The algorithm type to load.
20+
space: Optional; A function defining the search space. Use this to define more complex
21+
search spaces that cannot be represented using the built-in parameter types.
2022
study_name: Optional; The name of the study.
2123
storage: Optional; The storage URI to save the optimisation results to.
2224
"""
2325

2426
type: _t.Literal["ray.tune.search.optuna.OptunaSearch"] = "ray.tune.search.optuna.OptunaSearch"
27+
space: str | None = None
2528
study_name: str | None = None
2629
storage: str | None = None
2730

plugboard/tune/tune.py

Lines changed: 103 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -55,25 +55,19 @@ def __init__(
5555
algorithm: Configuration for the underlying Optuna algorithm used for optimisation.
5656
"""
5757
self._logger = DI.logger.resolve_sync().bind(cls=self.__class__.__name__)
58-
# Check that objective and mode are lists of the same length if multiple objectives are used
58+
# Validate and normalize objective/mode
5959
self._check_objective(objective, mode)
60-
self._objective = objective if isinstance(objective, list) else [objective]
61-
self._mode = [str(m) for m in mode] if isinstance(mode, list) else str(mode)
62-
self._metric = (
63-
[obj.full_name for obj in self._objective]
64-
if len(self._objective) > 1
65-
else self._objective[0].full_name
60+
self._objective, self._mode, self._metric = self._normalize_objective_and_mode(
61+
objective, mode
6662
)
63+
self._custom_space = bool(algorithm and algorithm.space)
6764

68-
self._parameters_dict = {p.full_name: p for p in parameters}
69-
self._parameters = dict(self._build_parameter(p) for p in parameters)
70-
_algo = self._build_algorithm(algorithm)
71-
if max_concurrent is not None:
72-
_algo = ray.tune.search.ConcurrencyLimiter(_algo, max_concurrent)
73-
self._config = ray.tune.TuneConfig(
74-
num_samples=num_samples,
75-
search_alg=_algo,
76-
)
65+
# Prepare parameters and search algorithm
66+
self._parameters_dict, self._parameters = self._prepare_parameters(parameters)
67+
searcher = self._init_search_algorithm(algorithm, max_concurrent)
68+
69+
# Configure Ray Tune
70+
self._config = ray.tune.TuneConfig(num_samples=num_samples, search_alg=searcher)
7771
self._result_grid: _t.Optional[ray.tune.ResultGrid] = None
7872
self._logger.info("Tuner created")
7973

@@ -105,33 +99,58 @@ def _build_algorithm(
10599
) -> ray.tune.search.Searcher:
106100
if algorithm is None:
107101
self._logger.info("Using default Optuna search algorithm")
108-
return ray.tune.search.optuna.OptunaSearch(metric=self._metric, mode=self._mode)
109-
_algo_kwargs = {
110-
**algorithm.model_dump(exclude={"type"}),
111-
"mode": self._mode,
112-
"metric": self._metric,
113-
}
114-
115-
# Convert storage URI string to optuna storage object if needed
116-
# TODO: Make this more general to support other algorithms, e.g. use a builder class
117-
if "storage" in _algo_kwargs and isinstance(_algo_kwargs["storage"], str):
118-
_algo_kwargs["storage"] = optuna.storages.RDBStorage(url=_algo_kwargs["storage"])
119-
self._logger.info(
120-
"Converted storage URI to Optuna RDBStorage object",
121-
storage_uri=algorithm.storage,
122-
)
102+
return self._default_searcher()
123103

124-
algo_cls: _t.Optional[_t.Any] = locate(algorithm.type)
125-
if not algo_cls or not issubclass(algo_cls, ray.tune.search.searcher.Searcher):
126-
raise ValueError(f"Could not locate `Searcher` class {algorithm.type}")
104+
algo_kwargs = self._build_algo_kwargs(algorithm)
105+
algo_cls = self._get_algo_class(algorithm.type)
127106
self._logger.info(
128107
"Using custom search algorithm",
129108
algorithm=algorithm.type,
130-
params={
131-
k: v if k != "storage" else f"<{type(v).__name__}>" for k, v in _algo_kwargs.items()
132-
},
109+
params={k: self._mask_param_value(k, v) for k, v in algo_kwargs.items()},
133110
)
134-
return algo_cls(**_algo_kwargs)
111+
return algo_cls(**algo_kwargs)
112+
113+
def _default_searcher(self) -> "ray.tune.search.Searcher":
114+
return ray.tune.search.optuna.OptunaSearch(metric=self._metric, mode=self._mode)
115+
116+
def _build_algo_kwargs(self, algorithm: OptunaSpec) -> dict[str, _t.Any]:
117+
"""Prepare keyword args for the searcher, normalising storage/space."""
118+
kwargs = algorithm.model_dump(exclude={"type"})
119+
kwargs["mode"] = self._mode
120+
kwargs["metric"] = self._metric
121+
122+
storage = kwargs.get("storage")
123+
if isinstance(storage, str):
124+
kwargs["storage"] = optuna.storages.RDBStorage(url=storage)
125+
self._logger.info(
126+
"Converted storage URI to Optuna RDBStorage object",
127+
storage_uri=storage,
128+
)
129+
130+
space = kwargs.get("space")
131+
if space is not None:
132+
kwargs["space"] = self._resolve_space_fn(space)
133+
134+
return kwargs
135+
136+
def _resolve_space_fn(self, space: str) -> _t.Callable:
137+
space_fn = locate(space)
138+
if not space_fn or not isfunction(space_fn): # pragma: no cover
139+
raise ValueError(f"Could not locate search space function {space}")
140+
return space_fn
141+
142+
def _get_algo_class(self, type_path: str) -> _t.Type[ray.tune.search.searcher.Searcher]:
143+
algo_cls: _t.Optional[_t.Any] = locate(type_path)
144+
if not algo_cls or not issubclass(
145+
algo_cls, ray.tune.search.searcher.Searcher
146+
): # pragma: no cover
147+
raise ValueError(f"Could not locate `Searcher` class {type_path}")
148+
return algo_cls
149+
150+
def _mask_param_value(self, k: str, v: _t.Any) -> _t.Any:
151+
if k == "storage" or (k == "space" and isfunction(v)):
152+
return f"<{type(v).__name__}>"
153+
return v
135154

136155
def _build_parameter(
137156
self, parameter: ParameterSpec
@@ -203,12 +222,16 @@ def run(self, spec: ProcessSpec) -> ray.tune.Result | list[ray.tune.Result]:
203222
),
204223
)
205224

206-
self._logger.info("Setting Tuner with parameters", params=list(self._parameters.keys()))
207-
_tune = ray.tune.Tuner(
208-
trainable_with_resources,
209-
param_space=self._parameters,
210-
tune_config=self._config,
211-
)
225+
tuner_kwargs: dict[str, _t.Any] = {
226+
"tune_config": self._config,
227+
}
228+
if not self._custom_space:
229+
self._logger.info("Setting Tuner with parameters", params=list(self._parameters.keys()))
230+
tuner_kwargs["param_space"] = self._parameters
231+
else:
232+
self._logger.info("Setting Tuner with custom search space")
233+
234+
_tune = ray.tune.Tuner(trainable_with_resources, **tuner_kwargs)
212235
self._logger.info("Starting Tuner")
213236
self._result_grid = _tune.fit()
214237
self._logger.info("Tuner finished")
@@ -230,6 +253,10 @@ def fn(config: dict[str, _t.Any]) -> dict[str, _t.Any]: # pragma: no cover
230253
ComponentRegistry.add(cls, key=key)
231254

232255
for name, value in config.items():
256+
if name not in self._parameters_dict:
257+
# Custom search spaces may include intermediate parameters not in the Tuner
258+
self._logger.warning("Parameter from config not found in Tuner", param=name)
259+
continue
233260
self._override_parameter(spec, self._parameters_dict[name], value)
234261

235262
process = ProcessBuilder.build(spec)
@@ -253,3 +280,35 @@ def fn(config: dict[str, _t.Any]) -> dict[str, _t.Any]: # pragma: no cover
253280
return result
254281

255282
return fn
283+
284+
def _normalize_objective_and_mode(
285+
self,
286+
objective: ObjectiveSpec | list[ObjectiveSpec],
287+
mode: Direction | list[Direction],
288+
) -> tuple[list[ObjectiveSpec], str | list[str], str | list[str]]:
289+
"""Return normalized objectives, modes and metric name(s)."""
290+
objectives = objective if isinstance(objective, list) else [objective]
291+
modes = [str(m) for m in mode] if isinstance(mode, list) else str(mode)
292+
metric = (
293+
[obj.full_name for obj in objectives]
294+
if len(objectives) > 1
295+
else objectives[0].full_name
296+
)
297+
return objectives, modes, metric
298+
299+
def _prepare_parameters(
300+
self, parameters: list[ParameterSpec]
301+
) -> tuple[dict[str, ParameterSpec], dict[str, "ray.tune.search.sample.Sampler"]]:
302+
"""Build parameter lookup dict and Ray Tune parameter space."""
303+
params_dict = {p.full_name: p for p in parameters}
304+
params_space = dict(self._build_parameter(p) for p in parameters)
305+
return params_dict, params_space
306+
307+
def _init_search_algorithm(
308+
self, algorithm: _t.Optional[OptunaSpec], max_concurrent: _t.Optional[int]
309+
) -> "ray.tune.search.Searcher":
310+
"""Create the search algorithm and apply concurrency limits if requested."""
311+
algo = self._build_algorithm(algorithm)
312+
if max_concurrent is not None:
313+
algo = ray.tune.search.ConcurrencyLimiter(algo, max_concurrent)
314+
return algo
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
plugboard:
2+
process:
3+
args:
4+
components:
5+
- type: tests.integration.test_process_with_components_run.A
6+
args:
7+
name: "a"
8+
iters: 10
9+
- type: tests.integration.test_tuner.DynamicListComponent
10+
args:
11+
name: "d"
12+
list_param: [1.0, 2.0, 3.0]
13+
- type: tests.integration.test_process_with_components_run.C
14+
args:
15+
name: "c"
16+
path: "./c.txt"
17+
connectors:
18+
- source: "a.out_1"
19+
target: "d.in_1"
20+
- source: "d.out_1"
21+
target: "c.in_1"

0 commit comments

Comments
 (0)