Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 1 addition & 32 deletions compiler_opt/baseline_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,43 +41,29 @@ def track_score(lst):
score_asked_for.extend(lst)
return [mock[k] if k in mock else None for k in lst]

cache = baseline_cache.BaselineCache(get_key=lambda x: x)
cache = baseline_cache.BaselineCache(get_key=lambda x: x)
self.assertEmpty(cache.get_cache())
self.assertEqual(
cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2])
self.assertEqual(
cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2])
self.assertDictEqual(cache.get_cache(), {"b": 2, "c": 3})
self.assertListEqual(sorted(score_asked_for), sorted(["c", "b"]))
score_asked_for.clear()

self.assertEqual(
cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2])
self.assertEqual(
cache.get_score(["c", "b"], get_scores_func=track_score), [3, 2])
self.assertListEqual(score_asked_for, [])
self.assertEqual(
cache.get_score(["a", "c", "b"], get_scores_func=track_score),
[1, 3, 2])
self.assertEqual(
cache.get_score(["a", "c", "b"], get_scores_func=track_score),
[1, 3, 2])
self.assertListEqual(score_asked_for, ["a"])
score_asked_for.clear()

self.assertEqual(
cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score),
[1, None, 3, 2])
self.assertEqual(
cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score),
[1, None, 3, 2])
self.assertListEqual(score_asked_for, ["n"])
score_asked_for.clear()

self.assertEqual(
cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score),
[1, None, 3, 2])
self.assertEqual(
cache.get_score(["a", "n", "c", "b"], get_scores_func=track_score),
[1, None, 3, 2])
Expand All @@ -98,21 +84,6 @@ def track_score(lst):
[3, 2, 3, 2])
self.assertListEqual(sorted(score_asked_for), sorted(["c", "b"]))

def test_duplicates(self):
  """Checks a request with repeated keys against an initially empty cache.

  The returned scores must line up with the requested positions, while the
  scoring function is asked for each distinct key only once.
  """
  known_scores = {"a": 1, "b": 2, "c": 3}
  requested = []

  def track_score(keys):
    # Record every key the cache forwards to the scoring function.
    requested.extend(keys)
    return [known_scores.get(key) for key in keys]

  cache = baseline_cache.BaselineCache(get_key=lambda x: x)
  self.assertEmpty(cache.get_cache())

  scores = cache.get_score(["c", "b", "c", "b"], get_scores_func=track_score)
  self.assertEqual(scores, [3, 2, 3, 2])
  self.assertListEqual(sorted(requested), ["b", "c"])

def test_with_workers(self):
with local_worker_manager.LocalWorkerPoolManager(
worker_class=MockWorker, count=4) as lwm:
Expand All @@ -129,14 +100,12 @@ def get_scores(items: list[str]):
futures, return_when=concurrent.futures.ALL_COMPLETED)
return [f.result() if f.exception() is None else None for f in futures]

cache = baseline_cache.BaselineCache(get_key=lambda x: x)
cache = baseline_cache.BaselineCache(get_key=lambda x: x)
self.assertEmpty(cache.get_cache())
self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2])
self.assertListEqual(score_asked_for, ["4", "2"])
self.assertListEqual(sorted(score_asked_for), sorted(["4", "2"]))
self.assertDictEqual(cache.get_cache(), {"4": 4, "2": 2})
score_asked_for.clear()

self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2])
self.assertEqual(cache.get_score(["4", "2"], get_scores), [4, 2])
self.assertListEqual(score_asked_for, [])
67 changes: 48 additions & 19 deletions compiler_opt/es/blackbox_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from compiler_opt.es import blackbox_optimizers
from compiler_opt.distributed import buffered_scheduler
from compiler_opt.rl import compilation_runner
from compiler_opt import baseline_cache


def _extract_results(futures: list[concurrent.futures.Future]) -> list[Any]:
Expand All @@ -51,6 +52,8 @@ def __init__(self, *, train_corpus: corpus.Corpus,
estimator_type: blackbox_optimizers.EstimatorType):
self._train_corpus = train_corpus
self._estimator_type = estimator_type
self._baseline_cache = baseline_cache.BaselineCache(
get_key=lambda x: x.name)

@abc.abstractmethod
def get_results(
Expand All @@ -73,12 +76,17 @@ def __init__(self,
num_ir_repeats_within_worker: int = 1,
**kwargs):
super().__init__(**kwargs)
self._samples: list[list[corpus.LoadedModuleSpec]] = []
self._total_num_perturbations = total_num_perturbations
self._num_ir_repeats_within_worker = num_ir_repeats_within_worker
self._baselines: list[float | None] | None = None
self._reset()

def _load_samples(self) -> None:
def _reset(self):
# TODO: this object is currently supposed to respect a state transition
# and that makes it less maintainable than if not.
self._samples = None
self._baselines = None

def load_samples(self) -> None:
"""Samples and loads modules if not already done.

Ensures self._samples contains the expected number of loaded samples.
Expand All @@ -89,6 +97,7 @@ def _load_samples(self) -> None:
"""
if self._samples:
raise RuntimeError('Samples have already been loaded.')
self._samples = []
for _ in range(self._total_num_perturbations):
samples = self._train_corpus.sample(self._num_ir_repeats_within_worker)
loaded_samples = [
Expand All @@ -108,15 +117,15 @@ def _load_samples(self) -> None:
if len(self._samples) != expected_count:
raise RuntimeError('Some samples could not be loaded correctly.')

def _launch_compilation_workers(self,
pool: FixedWorkerPool,
perturbations: list[bytes] | None = None
) -> list[concurrent.futures.Future]:
if self._samples is None:
raise RuntimeError('Loaded samples are not available.')
def _launch_compilation_workers(
self,
pool: FixedWorkerPool,
samples: list[list[corpus.LoadedModuleSpec]],
perturbations: list[bytes] | None = None
) -> list[concurrent.futures.Future]:
if perturbations is None:
perturbations = [None] * len(self._samples)
compile_args = zip(perturbations, self._samples)
perturbations = [None] * len(samples)
compile_args = zip(perturbations, samples)
_, futures = buffered_scheduler.schedule_on_worker_pool(
action=lambda w, args: w.compile(policy=args[0], modules=args[1]),
jobs=compile_args,
Expand All @@ -130,24 +139,43 @@ def _launch_compilation_workers(self,
not_done, return_when=concurrent.futures.FIRST_COMPLETED)
return futures

def ensure_baselines(self, pool):
  """Computes one baseline per sample group and stores them in _baselines.

  Every module of every group in self._samples is scored through
  self._baseline_cache; modules the cache cannot answer are compiled on
  `pool`, one module per compilation job. The per-module scores are then
  summed back into one value per group.

  Args:
    pool: the worker pool handed to _launch_compilation_workers for modules
      the cache asks to have scored.

  Raises:
    RuntimeError: if self._samples is None (samples were not loaded).
  """
  if self._samples is None:
    raise RuntimeError('Loaded samples are not available.')
  # flatten the samples.
  flat_samples = [item for sublist in self._samples for item in sublist]

  def _get_scores(some_list):
    # Each module is wrapped in its own single-element list so that it is
    # compiled (and scored) on its own.
    futures = self._launch_compilation_workers(pool, [[x] for x in some_list])
    return _extract_results(futures)

  baselines = self._baseline_cache.get_score(flat_samples, _get_scores)

  # TODO: the business of accumulating compilation results is now shared
  # with the worker.
  def sum_or_none(lst):
    # A group has no usable baseline if any of its modules has none.
    return sum(lst) if all(x is not None for x in lst) else None

  # Walk the flat score list with a running offset: group k owns the next
  # len(group k) entries. Using the group index itself as the slice start
  # would make the slices overlap whenever a group has more than one module.
  grouped_baselines = []
  offset = 0
  for group in self._samples:
    end = offset + len(group)
    grouped_baselines.append(sum_or_none(baselines[offset:end]))
    offset = end
  self._baselines = grouped_baselines

def get_results(
self, pool: FixedWorkerPool,
perturbations: list[bytes]) -> list[concurrent.futures.Future]:
# We should have _samples by now.
if not self._samples:
raise RuntimeError('Loaded samples are not available.')
return self._launch_compilation_workers(pool, perturbations)
self.load_samples()
self.ensure_baselines(pool)
return self._launch_compilation_workers(pool, self._samples, perturbations)

def set_baseline(self, pool: FixedWorkerPool) -> None:
if self._baselines is not None:
raise RuntimeError('The baseline has already been set.')
self._load_samples()
results_futures = self._launch_compilation_workers(pool)
self._baselines = _extract_results(results_futures)
pass

def get_rewards(
self,
results_futures: list[concurrent.futures.Future]) -> list[float | None]:
# we need a pool to get the baselines, so we should have gotten them already
if self._baselines is None:
raise RuntimeError('The baseline has not been set.')

Expand All @@ -165,6 +193,7 @@ def get_rewards(
else:
rewards.append(
compilation_runner.calculate_reward(policy_result, baseline))
self._reset()
return rewards


Expand Down
7 changes: 4 additions & 3 deletions compiler_opt/es/blackbox_evaluator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def test_sampling_set_baseline(self):
train_corpus=test_corpus,
estimator_type=blackbox_optimizers.EstimatorType.FORWARD_FD,
total_num_perturbations=1)

evaluator.set_baseline(pool)
evaluator.load_samples()
evaluator.ensure_baselines(pool)
# pylint: disable=protected-access
self.assertAlmostEqual(evaluator._baselines, [10])

Expand Down Expand Up @@ -90,7 +90,8 @@ def test_sampling_get_rewards_with_baseline(self):
estimator_type=blackbox_optimizers.EstimatorType.FORWARD_FD,
total_num_perturbations=2)

evaluator.set_baseline(pool)
evaluator.load_samples()
evaluator.ensure_baselines(pool)

f_policy1 = concurrent.futures.Future()
f_policy1.set_result(1.5)
Expand Down