From 828298fde89b7c4dd3a30e1e871020a026c2fb47 Mon Sep 17 00:00:00 2001 From: Mircea Trofin Date: Thu, 12 Feb 2026 21:25:44 -0800 Subject: [PATCH 1/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20Ba?= =?UTF-8?q?se=20of=20Pull=20Request=20#550?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request: https://github.com/google/ml-compiler-opt/pull/551 --- compiler_opt/baseline_cache_test.py | 33 ++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/compiler_opt/baseline_cache_test.py b/compiler_opt/baseline_cache_test.py index 4885c8a9..d9cc9364 100644 --- a/compiler_opt/baseline_cache_test.py +++ b/compiler_opt/baseline_cache_test.py @@ -41,29 +41,43 @@ def track_score(lst): score_asked_for.extend(lst) return [mock[k] if k in mock else None for k in lst] + cache = baseline_cache.BaselineCache(get_key=lambda x: x) cache = baseline_cache.BaselineCache(get_key=lambda x: x) self.assertEmpty(cache.get_cache()) + self.assertEqual( + cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2]) self.assertEqual( cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2]) self.assertDictEqual(cache.get_cache(), {"b": 2, "c": 3}) self.assertListEqual(sorted(score_asked_for), sorted(["c", "b"])) score_asked_for.clear() + self.assertEqual( + cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2]) self.assertEqual( cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2]) self.assertListEqual(score_asked_for, []) + self.assertEqual( + cache.get_score(["a", "c", "b"], get_scores_func=track_score), + [1, 3, 2]) self.assertEqual( cache.get_score(["a", "c", "b"], get_scores_func=track_score), [1, 3, 2]) self.assertListEqual(score_asked_for, ["a"]) score_asked_for.clear() + self.assertEqual( + cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score), + [1, None, 3, 2]) self.assertEqual( cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score), [1, None, 3, 2]) self.assertListEqual(score_asked_for, ["n"]) score_asked_for.clear() + self.assertEqual( + cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score), + [1, None, 3, 2]) self.assertEqual( cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score), [1, None, 3, 2]) @@ -84,6 +98,21 @@ def track_score(lst): [3, 2, 3, 2]) self.assertListEqual(sorted(score_asked_for), sorted(["c", "b"])) + def test_duplicates(self): + mock = {"a": 1, "b": 2, "c": 3} + score_asked_for = [] + + def track_score(lst): + score_asked_for.extend(lst) + return [mock[k] if k in mock else None for k in lst] + + cache = baseline_cache.BaselineCache(get_key=lambda x: x) + self.assertEmpty(cache.get_cache()) + self.assertEqual( + cache.get_score(["c", "b", "c", "b"], get_scores_func=track_score), + [3, 2, 3, 2]) + self.assertListEqual(sorted(score_asked_for), sorted(["c", "b"])) + def test_with_workers(self): with local_worker_manager.LocalWorkerPoolManager( worker_class=MockWorker, count=4) as lwm: @@ -100,12 +129,14 @@ def get_scores(items: list[str]): futures, return_when=concurrent.futures.ALL_COMPLETED) return [f.result() if f.exception() is None else None for f in futures] + cache = baseline_cache.BaselineCache(get_key=lambda x: x) cache = baseline_cache.BaselineCache(get_key=lambda x: x) self.assertEmpty(cache.get_cache()) self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2]) - self.assertListEqual(sorted(score_asked_for), sorted(["4", "2"])) + self.assertListEqual(score_asked_for, ["4", "2"]) self.assertDictEqual(cache.get_cache(), {"4": 4, "2": 2}) score_asked_for.clear() + self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2]) self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2]) self.assertListEqual(score_asked_for, []) From 5f3a0f759887d6fd73501d52d9a102ca10222344 Mon Sep 17 00:00:00 2001 From: Mircea Trofin Date: Thu, 12 Feb 2026 21:25:44 -0800 Subject: [PATCH 2/2] Use the baseline cache for the sampling evaluator Reviewers: boomanaiden154 Pull Request: https://github.com/google/ml-compiler-opt/pull/550 --- compiler_opt/baseline_cache_test.py | 33 +---------- compiler_opt/es/blackbox_evaluator.py | 67 ++++++++++++++++------ compiler_opt/es/blackbox_evaluator_test.py | 7 ++- 3 files changed, 53 insertions(+), 54 deletions(-) diff --git a/compiler_opt/baseline_cache_test.py b/compiler_opt/baseline_cache_test.py index d9cc9364..4885c8a9 100644 --- a/compiler_opt/baseline_cache_test.py +++ b/compiler_opt/baseline_cache_test.py @@ -41,43 +41,29 @@ def track_score(lst): score_asked_for.extend(lst) return [mock[k] if k in mock else None for k in lst] - cache = baseline_cache.BaselineCache(get_key=lambda x: x) cache = baseline_cache.BaselineCache(get_key=lambda x: x) self.assertEmpty(cache.get_cache()) - self.assertEqual( - cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2]) self.assertEqual( cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2]) self.assertDictEqual(cache.get_cache(), {"b": 2, "c": 3}) self.assertListEqual(sorted(score_asked_for), sorted(["c", "b"])) score_asked_for.clear() - self.assertEqual( - cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2]) self.assertEqual( cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2]) self.assertListEqual(score_asked_for, []) - self.assertEqual( - cache.get_score(["a", "c", "b"], get_scores_func=track_score), - [1, 3, 2]) self.assertEqual( cache.get_score(["a", "c", "b"], get_scores_func=track_score), [1, 3, 2]) self.assertListEqual(score_asked_for, ["a"]) score_asked_for.clear() - self.assertEqual( - cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score), - [1, None, 3, 2]) self.assertEqual( cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score), [1, None, 3, 2]) self.assertListEqual(score_asked_for, ["n"]) score_asked_for.clear() - self.assertEqual( - cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score), - [1, None, 3, 2]) self.assertEqual( cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score), [1, None, 3, 2]) @@ -98,21 +84,6 @@ def track_score(lst): [3, 2, 3, 2]) self.assertListEqual(sorted(score_asked_for), sorted(["c", "b"])) - def test_duplicates(self): - mock = {"a": 1, "b": 2, "c": 3} - score_asked_for = [] - - def track_score(lst): - score_asked_for.extend(lst) - return [mock[k] if k in mock else None for k in lst] - - cache = baseline_cache.BaselineCache(get_key=lambda x: x) - self.assertEmpty(cache.get_cache()) - self.assertEqual( - cache.get_score(["c", "b", "c", "b"], get_scores_func=track_score), - [3, 2, 3, 2]) - self.assertListEqual(sorted(score_asked_for), sorted(["c", "b"])) - def test_with_workers(self): with local_worker_manager.LocalWorkerPoolManager( worker_class=MockWorker, count=4) as lwm: @@ -129,14 +100,12 @@ def get_scores(items: list[str]): futures, return_when=concurrent.futures.ALL_COMPLETED) return [f.result() if f.exception() is None else None for f in futures] - cache = baseline_cache.BaselineCache(get_key=lambda x: x) cache = baseline_cache.BaselineCache(get_key=lambda x: x) self.assertEmpty(cache.get_cache()) self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2]) - self.assertListEqual(score_asked_for, ["4", "2"]) + self.assertListEqual(sorted(score_asked_for), sorted(["4", "2"])) self.assertDictEqual(cache.get_cache(), {"4": 4, "2": 2}) score_asked_for.clear() - self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2]) self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2]) self.assertListEqual(score_asked_for, []) diff --git a/compiler_opt/es/blackbox_evaluator.py b/compiler_opt/es/blackbox_evaluator.py index 8e0f7833..d84e8ae6 100644 --- a/compiler_opt/es/blackbox_evaluator.py +++ b/compiler_opt/es/blackbox_evaluator.py @@ -28,6 +28,7 @@ from compiler_opt.es import blackbox_optimizers from compiler_opt.distributed import buffered_scheduler from compiler_opt.rl import compilation_runner +from compiler_opt import baseline_cache def _extract_results(futures: list[concurrent.futures.Future]) -> list[Any]: @@ -51,6 +52,8 @@ def __init__(self, *, train_corpus: corpus.Corpus, estimator_type: blackbox_optimizers.EstimatorType): self._train_corpus = train_corpus self._estimator_type = estimator_type + self._baseline_cache = baseline_cache.BaselineCache( + get_key=lambda x: x.name) @abc.abstractmethod def get_results( @@ -73,12 +76,17 @@ def __init__(self, num_ir_repeats_within_worker: int = 1, **kwargs): super().__init__(**kwargs) - self._samples: list[list[corpus.LoadedModuleSpec]] = [] self._total_num_perturbations = total_num_perturbations self._num_ir_repeats_within_worker = num_ir_repeats_within_worker - self._baselines: list[float | None] | None = None + self._reset() - def _load_samples(self) -> None: + def _reset(self): + # TODO: this object is currently supposed to respect a state transition + # and that makes it less maintainable than if not. + self._samples = None + self._baselines = None + + def load_samples(self) -> None: """Samples and loads modules if not already done. Ensures self._samples contains the expected number of loaded samples. @@ -89,6 +97,7 @@ def _load_samples(self) -> None: """ if self._samples: raise RuntimeError('Samples have already been loaded.') + self._samples = [] for _ in range(self._total_num_perturbations): samples = self._train_corpus.sample(self._num_ir_repeats_within_worker) loaded_samples = [ @@ -108,15 +117,15 @@ def _load_samples(self) -> None: if len(self._samples) != expected_count: raise RuntimeError('Some samples could not be loaded correctly.') - def _launch_compilation_workers(self, - pool: FixedWorkerPool, - perturbations: list[bytes] | None = None - ) -> list[concurrent.futures.Future]: - if self._samples is None: - raise RuntimeError('Loaded samples are not available.') + def _launch_compilation_workers( + self, + pool: FixedWorkerPool, + samples: list[list[corpus.LoadedModuleSpec]], + perturbations: list[bytes] | None = None + ) -> list[concurrent.futures.Future]: if perturbations is None: - perturbations = [None] * len(self._samples) - compile_args = zip(perturbations, self._samples) + perturbations = [None] * len(samples) + compile_args = zip(perturbations, samples) _, futures = buffered_scheduler.schedule_on_worker_pool( action=lambda w, args: w.compile(policy=args[0], modules=args[1]), jobs=compile_args, @@ -130,24 +139,43 @@ def _launch_compilation_workers(self, not_done, return_when=concurrent.futures.FIRST_COMPLETED) return futures + def ensure_baselines(self, pool): + if self._samples is None: + raise RuntimeError('Loaded samples are not available.') + # flatten the samples. + flat_samples = [item for sublist in self._samples for item in sublist] + + def _get_scores(some_list): + futures = self._launch_compilation_workers(pool, [[x] for x in some_list]) + return _extract_results(futures) + + baselines = self._baseline_cache.get_score(flat_samples, _get_scores) + + # TODO: the business of accummulating compilation results is now shared + # with the worker. + def sum_or_none(lst): + return sum(lst) if all(x is not None for x in lst) else None + + self._baselines = [ + sum_or_none(baselines[i:i + len(self._samples[i])]) + for i in range(len(self._samples)) + ] + def get_results( self, pool: FixedWorkerPool, perturbations: list[bytes]) -> list[concurrent.futures.Future]: - # We should have _samples by now. if not self._samples: - raise RuntimeError('Loaded samples are not available.') - return self._launch_compilation_workers(pool, perturbations) + self.load_samples() + self.ensure_baselines(pool) + return self._launch_compilation_workers(pool, self._samples, perturbations) def set_baseline(self, pool: FixedWorkerPool) -> None: - if self._baselines is not None: - raise RuntimeError('The baseline has already been set.') - self._load_samples() - results_futures = self._launch_compilation_workers(pool) - self._baselines = _extract_results(results_futures) + pass def get_rewards( self, results_futures: list[concurrent.futures.Future]) -> list[float | None]: + # we need a pool to get the baselines, so we should have gotten them already if self._baselines is None: raise RuntimeError('The baseline has not been set.') @@ -165,6 +193,7 @@ def get_rewards( else: rewards.append( compilation_runner.calculate_reward(policy_result, baseline)) + self._reset() return rewards diff --git a/compiler_opt/es/blackbox_evaluator_test.py b/compiler_opt/es/blackbox_evaluator_test.py index 1d68d2c9..0c71bd0d 100644 --- a/compiler_opt/es/blackbox_evaluator_test.py +++ b/compiler_opt/es/blackbox_evaluator_test.py @@ -61,8 +61,8 @@ def test_sampling_set_baseline(self): train_corpus=test_corpus, estimator_type=blackbox_optimizers.EstimatorType.FORWARD_FD, total_num_perturbations=1) - - evaluator.set_baseline(pool) + evaluator.load_samples() + evaluator.ensure_baselines(pool) # pylint: disable=protected-access self.assertAlmostEqual(evaluator._baselines, [10]) @@ -90,7 +90,8 @@ def test_sampling_get_rewards_with_baseline(self): estimator_type=blackbox_optimizers.EstimatorType.FORWARD_FD, total_num_perturbations=2) - evaluator.set_baseline(pool) + evaluator.load_samples() + evaluator.ensure_baselines(pool) f_policy1 = concurrent.futures.Future() f_policy1.set_result(1.5)