diff --git a/compiler_opt/es/blackbox_evaluator.py b/compiler_opt/es/blackbox_evaluator.py
index 8e0f7833..d84e8ae6 100644
--- a/compiler_opt/es/blackbox_evaluator.py
+++ b/compiler_opt/es/blackbox_evaluator.py
@@ -28,6 +28,7 @@
 from compiler_opt.es import blackbox_optimizers
 from compiler_opt.distributed import buffered_scheduler
 from compiler_opt.rl import compilation_runner
+from compiler_opt import baseline_cache
 
 
 def _extract_results(futures: list[concurrent.futures.Future]) -> list[Any]:
@@ -51,6 +52,8 @@ def __init__(self, *, train_corpus: corpus.Corpus,
                estimator_type: blackbox_optimizers.EstimatorType):
     self._train_corpus = train_corpus
     self._estimator_type = estimator_type
+    self._baseline_cache = baseline_cache.BaselineCache(
+        get_key=lambda x: x.name)
 
   @abc.abstractmethod
   def get_results(
@@ -73,12 +76,17 @@ def __init__(self,
                num_ir_repeats_within_worker: int = 1,
                **kwargs):
     super().__init__(**kwargs)
-    self._samples: list[list[corpus.LoadedModuleSpec]] = []
     self._total_num_perturbations = total_num_perturbations
     self._num_ir_repeats_within_worker = num_ir_repeats_within_worker
-    self._baselines: list[float | None] | None = None
+    self._reset()
 
-  def _load_samples(self) -> None:
+  def _reset(self) -> None:
+    # TODO: this object currently has to follow an implicit state machine
+    # (load -> baselines -> rewards), which makes it harder to maintain.
+    self._samples: list[list[corpus.LoadedModuleSpec]] | None = None
+    self._baselines: list[float | None] | None = None
+
+  def load_samples(self) -> None:
     """Samples and loads modules if not already done.
 
     Ensures self._samples contains the expected number of loaded samples.
@@ -89,6 +97,7 @@ def _load_samples(self) -> None:
     """
     if self._samples:
       raise RuntimeError('Samples have already been loaded.')
+    self._samples = []
     for _ in range(self._total_num_perturbations):
       samples = self._train_corpus.sample(self._num_ir_repeats_within_worker)
       loaded_samples = [
@@ -108,15 +117,15 @@ def _load_samples(self) -> None:
     if len(self._samples) != expected_count:
       raise RuntimeError('Some samples could not be loaded correctly.')
 
-  def _launch_compilation_workers(self,
-                                  pool: FixedWorkerPool,
-                                  perturbations: list[bytes] | None = None
-                                 ) -> list[concurrent.futures.Future]:
-    if self._samples is None:
-      raise RuntimeError('Loaded samples are not available.')
+  def _launch_compilation_workers(
+      self,
+      pool: FixedWorkerPool,
+      samples: list[list[corpus.LoadedModuleSpec]],
+      perturbations: list[bytes] | None = None
+  ) -> list[concurrent.futures.Future]:
     if perturbations is None:
-      perturbations = [None] * len(self._samples)
-    compile_args = zip(perturbations, self._samples)
+      perturbations = [None] * len(samples)
+    compile_args = zip(perturbations, samples)
     _, futures = buffered_scheduler.schedule_on_worker_pool(
         action=lambda w, args: w.compile(policy=args[0], modules=args[1]),
         jobs=compile_args,
@@ -130,24 +139,46 @@ def _launch_compilation_workers(self,
           not_done, return_when=concurrent.futures.FIRST_COMPLETED)
     return futures
 
+  def ensure_baselines(self, pool: FixedWorkerPool) -> None:
+    if self._samples is None:
+      raise RuntimeError('Loaded samples are not available.')
+    # Flatten the samples; scores come back in the same flat order.
+    flat_samples = [item for sublist in self._samples for item in sublist]
+
+    def _get_scores(modules):
+      futures = self._launch_compilation_workers(pool, [[m] for m in modules])
+      return _extract_results(futures)
+
+    baselines = self._baseline_cache.get_score(flat_samples, _get_scores)
+
+    # TODO: the logic for accumulating compilation results is now duplicated
+    # in the worker; factor it out.
+    def sum_or_none(lst):
+      return sum(lst) if all(x is not None for x in lst) else None
+
+    # Regroup flat per-module scores back into per-perturbation sums.
+    scores = iter(baselines)
+    self._baselines = [
+        sum_or_none([next(scores) for _ in sublist])
+        for sublist in self._samples
+    ]
+
   def get_results(
       self, pool: FixedWorkerPool,
       perturbations: list[bytes]) -> list[concurrent.futures.Future]:
-    # We should have _samples by now.
     if not self._samples:
-      raise RuntimeError('Loaded samples are not available.')
-    return self._launch_compilation_workers(pool, perturbations)
+      self.load_samples()
+    self.ensure_baselines(pool)
+    return self._launch_compilation_workers(pool, self._samples, perturbations)
 
   def set_baseline(self, pool: FixedWorkerPool) -> None:
-    if self._baselines is not None:
-      raise RuntimeError('The baseline has already been set.')
-    self._load_samples()
-    results_futures = self._launch_compilation_workers(pool)
-    self._baselines = _extract_results(results_futures)
+    # No-op: baselines are now computed lazily in get_results.
+    pass
 
   def get_rewards(
       self, results_futures: list[concurrent.futures.Future]) -> list[float | None]:
+    # Baselines need a worker pool, so get_results must have set them already.
     if self._baselines is None:
       raise RuntimeError('The baseline has not been set.')
@@ -165,6 +196,7 @@ def get_rewards(
       else:
         rewards.append(
             compilation_runner.calculate_reward(policy_result, baseline))
+    self._reset()
     return rewards
 
diff --git a/compiler_opt/es/blackbox_evaluator_test.py b/compiler_opt/es/blackbox_evaluator_test.py
index 1d68d2c9..0c71bd0d 100644
--- a/compiler_opt/es/blackbox_evaluator_test.py
+++ b/compiler_opt/es/blackbox_evaluator_test.py
@@ -61,8 +61,8 @@ def test_sampling_set_baseline(self):
         train_corpus=test_corpus,
         estimator_type=blackbox_optimizers.EstimatorType.FORWARD_FD,
         total_num_perturbations=1)
-
-    evaluator.set_baseline(pool)
+    evaluator.load_samples()
+    evaluator.ensure_baselines(pool)
 
     # pylint: disable=protected-access
     self.assertAlmostEqual(evaluator._baselines, [10])
@@ -90,7 +90,8 @@ def test_sampling_get_rewards_with_baseline(self):
         estimator_type=blackbox_optimizers.EstimatorType.FORWARD_FD,
         total_num_perturbations=2)
 
-    evaluator.set_baseline(pool)
+    evaluator.load_samples()
+    evaluator.ensure_baselines(pool)
 
     f_policy1 = concurrent.futures.Future()
     f_policy1.set_result(1.5)
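Reviewer note: `compiler_opt/baseline_cache.py` is imported by this change but not included in the diff. For context, here is a minimal sketch of the interface the evaluator relies on, assuming `get_score(items, scorer)` memoizes scores per `get_key(item)` and only invokes `scorer` for cache misses; the names and semantics are assumptions about the new module, not its actual contents:

```python
# Hypothetical sketch of the BaselineCache interface assumed above; the real
# compiler_opt/baseline_cache.py may differ.
from collections.abc import Callable, Sequence
from typing import Any


class BaselineCache:
  """Memoizes per-item scores, keyed by get_key(item)."""

  def __init__(self, get_key: Callable[[Any], str]):
    self._get_key = get_key
    self._scores: dict[str, float | None] = {}

  def get_score(
      self, items: Sequence[Any],
      scorer: Callable[[list[Any]], list[float | None]]
  ) -> list[float | None]:
    # Score only the keys not seen before; the dict also deduplicates
    # repeated items within a single call.
    misses = {
        self._get_key(x): x
        for x in items
        if self._get_key(x) not in self._scores
    }
    if misses:
      keys = list(misses)
      for key, score in zip(keys, scorer([misses[k] for k in keys])):
        self._scores[key] = score
    return [self._scores[self._get_key(x)] for x in items]
```

Under this reading, repeated `ensure_baselines` calls across training steps recompile a module's baseline only once, which is what makes calling it from every `get_results` affordable.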
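And a quick self-contained illustration of the regrouping at the end of `ensure_baselines` (helper name hypothetical): the flat score list must be split by cumulative sublist lengths, not indexed by sublist position, which is what the iterator-based version guarantees:

```python
# Toy check: scores arrive in the same order as the flattened samples, so we
# consume them sequentially, len(sublist) at a time.
def regroup(flat_scores, groups):
  it = iter(flat_scores)
  return [[next(it) for _ in group] for group in groups]


# Two perturbations, two modules each (num_ir_repeats_within_worker=2).
assert regroup([1, 2, 3, 4], [['a', 'b'], ['c', 'd']]) == [[1, 2], [3, 4]]
# Slicing flat_scores[i:i + len(groups[i])] instead would give [1, 2] and
# [2, 3], attributing module 'b's score to the second perturbation.
```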