diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index ba84675425..f2d393994f 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -409,7 +409,7 @@ def __init__( # pylint: disable=too-many-arguments experiment_id: str, trial_id: int, tunable_config_id: int, - trial_runner_id: int | None = None, + trial_runner_id: int | None, opt_targets: dict[str, Literal["min", "max"]], config: dict[str, Any] | None = None, status: Status = Status.UNKNOWN, diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index 62daa0232c..eb47de7d71 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -276,6 +276,7 @@ def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Stor experiment_id=self._experiment_id, trial_id=trial.trial_id, config_id=trial.config_id, + trial_runner_id=trial.trial_runner_id, opt_targets=self._opt_targets, config=config, ) @@ -350,6 +351,7 @@ def _new_trial( experiment_id=self._experiment_id, trial_id=self._trial_id, config_id=config_id, + trial_runner_id=None, # initially, Trials are not assigned to a TrialRunner opt_targets=self._opt_targets, config=config, status=new_trial_status, diff --git a/mlos_bench/mlos_bench/storage/sql/trial.py b/mlos_bench/mlos_bench/storage/sql/trial.py index 88fc05ef7d..a9960c466a 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial.py +++ b/mlos_bench/mlos_bench/storage/sql/trial.py @@ -38,7 +38,7 @@ def __init__( # pylint: disable=too-many-arguments experiment_id: str, trial_id: int, config_id: int, - trial_runner_id: int | None = None, + trial_runner_id: int | None, opt_targets: dict[str, Literal["min", "max"]], config: dict[str, Any] | None = None, status: Status = Status.UNKNOWN, diff --git a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py index a1ab74f9f5..aaf545c787 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py @@ -47,6 +47,10 @@ def test_schedule_trial( # Schedule 2 hours in the future: trial_2h = exp_storage.new_trial(tunable_groups, timestamp + timedelta_1hr * 2, config) + # Check that if we assign a TrialRunner that that value is still available on restore. + trial_now2.set_trial_runner(1) + assert trial_now2.trial_runner_id + exp_data = storage.experiments[exp_storage.experiment_id] trial_now1_data = exp_data.trials[trial_now1.trial_id] assert trial_now1_data.trial_runner_id is None @@ -54,6 +58,9 @@ def test_schedule_trial( # Check that Status matches in object vs. backend storage. assert trial_now1.status == trial_now1_data.status + trial_now2_data = exp_data.trials[trial_now2.trial_id] + assert trial_now2_data.trial_runner_id == trial_now2.trial_runner_id + # Scheduler side: get trials ready to run at certain timestamps: # Pretend 1 minute has passed, get trials scheduled to run: @@ -63,6 +70,16 @@ def test_schedule_trial( trial_now2.trial_id, } + # Make sure that the pending trials and trial_runner_ids match. + pending_trial_runner_ids = { + pending_trial.trial_id: pending_trial.trial_runner_id + for pending_trial in exp_storage.pending_trials(timestamp + timedelta_1min, running=False) + } + assert pending_trial_runner_ids == { + trial_now1.trial_id: trial_now1.trial_runner_id, + trial_now2.trial_id: trial_now2.trial_runner_id, + } + # Get trials scheduled to run within the next 1 hour: pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1hr, running=False)) assert pending_ids == {