Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[run]
omit = baybe/utils/plotting.py
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
Comment thread
Scienfitz marked this conversation as resolved.
Comment thread
AVHopp marked this conversation as resolved.
### Added
- Subpackages for the available recommender types
- Multi-style plotting capabilities for generated example plots
- JSON file for plotting themes
- Smoke testing in relevant tox environments

### Changed
- `Recommender`s now share their core logic via their base class
Expand Down
134 changes: 134 additions & 0 deletions baybe/utils/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Plotting utilities."""

import json
import os
import sys
import warnings
from pathlib import Path
from typing import Any, Dict, Tuple

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure


def create_example_plots(
ax: Axes,
path: Path,
base_name: str,
) -> None:
"""Create plots from an Axes object and save them as a svg file.

Comment thread
AVHopp marked this conversation as resolved.
The plots will be saved in the location specified by ``path``.
The attribute ``base_name`` is used to define the name of the outputs.

If the ``SMOKE_TEST`` variable is set, no plots are being created and this method
immediately returns.

The function attempts to read the predefined themes from ``plotting_themes.json``.
For each theme it finds, a file ``{base_name}_{theme}.svg`` is being created.
If the file cannot be found, if the JSON cannot be loaded or if the JSON is not well
configured, a fallback theme is used.

Args:
ax: The Axes object containing the figure that should be plotted.
path: The path to the directory in which the plots should be saved.
base_name: The base name that is used for naming the output files.
"""
# Check whether we immediately return due to just running a SMOKE_TEST
if "SMOKE_TEST" in os.environ:
return

# Define a fallback theme in case no configuration is found
fallback: Dict[str, Any] = {
"color": "black",
"figsize": (24, 8),
"fontsize": 22,
"framealpha": 0.3,
}

# Try to find the plotting themes by backtracking
# Get the absolute path of the current script
script_path = Path(sys.path[0]).resolve()
while (
not Path(script_path / "plotting_themes.json").is_file()
and script_path != script_path.parent
):
script_path = script_path.parent
if script_path == script_path.parent:
warnings.warn("No themes for plotting found. A fallback theme is used.")
themes = {"fallback": fallback}
else:
# Open the file containing all the themes
# If we reach this point, we know that the file exists, so we try to load it.
# If the file is no proper json, the fallback theme is used.
try:
themes = json.load(open(script_path / "plotting_themes.json"))
except json.JSONDecodeError:
warnings.warn(
"The JSON containing the themes could not be loaded."
"A fallback theme is used.",
UserWarning,
)
themes = {"fallback": fallback}

for theme_name in themes:
# Get all of the values from the themes
# TODO This can probably be generalized and improved later on such that the
# keys fit the rc_params of matplotlib
# TODO We might want to add a generalization here
necessary_keys = ("color", "figsize", "fontsize", "framealpha")
if not all(key in themes[theme_name] for key in necessary_keys):
warnings.warn(
"Provided theme does not contain the necessary keys."
"Using a fallback theme instead.",
UserWarning,
)
current_theme = fallback
else:
current_theme = themes[theme_name]
color: str = current_theme["color"]
figsize: Tuple[int, int] = current_theme["figsize"]
fontsize: int = current_theme["fontsize"]
framealpha: float = current_theme["framealpha"]

# Adjust the axes of the plot
for key in ax.spines.keys():
ax.spines[key].set_color(color)
ax.xaxis.label.set_color(color)
ax.xaxis.label.set_fontsize(fontsize)
ax.yaxis.label.set_color(color)
ax.yaxis.label.set_fontsize(fontsize)

# Adjust the size of the ax
# mypy thinks that ax.figure might become None, hence the explicit ignore
if isinstance(ax.figure, Figure):
ax.figure.set_size_inches(*figsize)
else:
warnings.warn("Could not adjust size of plot due to it not being a Figure.")

# Adjust the labels
for label in ax.get_xticklabels() + ax.get_yticklabels():
label.set_color(color)
label.set_fontsize(fontsize)

# Adjust the legend
legend = ax.get_legend()
legend.get_frame().set_alpha(framealpha)
legend.get_title().set_color(color)
legend.get_title().set_fontsize(fontsize)
for text in legend.get_texts():
text.set_fontsize(fontsize)
text.set_color(color)

output_path = Path(path, f"{base_name}_{theme_name}.svg")
# mypy thinks that ax.figure might become None, hence the explicit ignore
if isinstance(ax.figure, Figure):
ax.figure.savefig(
output_path,
format="svg",
transparent=True,
)
else:
warnings.warn("Plots could not be saved.")
plt.close()
20 changes: 19 additions & 1 deletion docs/scripts/utils.py
Comment thread
AVHopp marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def create_example_documentation(example_dest_dir: str, ignore_examples: bool):
# Include the name of the file to the toctree
# Format it by replacing underscores and capitalizing the words
file_name = file.stem

formatted = " ".join(word.capitalize() for word in file_name.split("_"))
# Remove duplicate "constraints" for the files in the constraints folder.
if "Constraints" in folder_name and "Constraints" in formatted:
Expand Down Expand Up @@ -162,7 +163,24 @@ def create_example_documentation(example_dest_dir: str, ignore_examples: bool):
with open(markdown_path, "r", encoding="UTF-8") as markdown_file:
lines = markdown_file.readlines()

# Delete lines we do not want to have in our documentation
lines = [line for line in lines if "![svg]" not in line]
lines = [line for line in lines if "![png]" not in line]
lines = [line for line in lines if "<Figure size" not in line]

# We check whether pre-built light and dark plots exist. If so, we append
# corresponding lines to our markdown file for including them.
light_figure = pathlib.Path(sub_directory / (file_name + "_light.svg"))
dark_figure = pathlib.Path(sub_directory / (file_name + "_dark.svg"))
if light_figure.is_file() and dark_figure.is_file():
lines.append(f"```{{image}} {file_name}_light.svg\n")
lines.append(":align: center\n")
lines.append(":class: only-light\n")
lines.append("```\n")
lines.append(f"```{{image}} {file_name}_dark.svg\n")
lines.append(":align: center\n")
lines.append(":class: only-dark\n")
lines.append("```\n")

# Rewrite the file
with open(markdown_path, "w", encoding="UTF-8") as markdown_file:
Expand All @@ -188,7 +206,7 @@ def create_example_documentation(example_dest_dir: str, ignore_examples: bool):
# Remove any not markdown files
for file in examples_directory.glob("**/*"):
if file.is_file():
if file.suffix != ".md" or "Header" in file.name:
if file.suffix not in (".md", ".svg") or "Header" in file.name:
file.unlink(file)

# Remove any remaining empty subdirectories
Expand Down
71 changes: 35 additions & 36 deletions examples/Backtesting/botorch_analytical.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
## Example for full simulation loop using a BoTorch test function
## Simulation loop using a BoTorch test function

# This example shows a simulation loop for a single target with a BoTorch test function as lookup.
# That is, we perform several Monte Carlo runs with several iterations.
# In addition, we also store and display the results.

# This example assumes some basic familiarity with using BayBE and how to use BoTorch test
# functions in discrete searchspaces.
Expand All @@ -11,30 +9,37 @@
# 2. [`discrete_space`](./../Searchspaces/discrete_space.md) for details on using a
# BoTorch test function.

### Necessary imports for this example
### Imports

import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from botorch.test_functions import Rastrigin

from baybe import Campaign
from baybe.objective import Objective
from baybe.parameters import NumericalDiscreteParameter
from baybe.recommenders import RandomRecommender, SequentialGreedyRecommender
from baybe.recommenders import RandomRecommender
from baybe.searchspace import SearchSpace
from baybe.simulation import simulate_scenarios
from baybe.strategies import TwoPhaseStrategy
from baybe.targets import NumericalTarget
from baybe.utils.botorch_wrapper import botorch_function_wrapper
from baybe.utils.plotting import create_example_plots

### Parameters for a full simulation loop

# For the full simulation, we need to define some additional parameters.
# These are the number of Monte Carlo runs and the number of experiments to be conducted per run.
# For the full simulation, we need to define the number of Monte Carlo runs
# and the number of experiments to be conducted per run.

SMOKE_TEST = "SMOKE_TEST" in os.environ

N_MC_ITERATIONS = 2
N_DOE_ITERATIONS = 2
N_MC_ITERATIONS = 2 if SMOKE_TEST else 30
N_DOE_ITERATIONS = 2 if SMOKE_TEST else 15
BATCH_SIZE = 1 if SMOKE_TEST else 3
Comment thread
AVHopp marked this conversation as resolved.
POINTS_PER_DIM = 10

### Defining the test function

Expand All @@ -59,10 +64,6 @@

### Creating the searchspace and the objective

# The parameter `POINTS_PER_DIM` controls the number of points per dimension.
# Note that the searchspace will have `POINTS_PER_DIM**DIMENSION` many points.

POINTS_PER_DIM = 10
parameters = [
NumericalDiscreteParameter(
name=f"x_{k+1}",
Expand All @@ -83,34 +84,21 @@
mode="SINGLE", targets=[NumericalTarget(name="Target", mode="MIN")]
)

### Constructing campaigns for the simulation loop

# To simplify adjusting the example for other strategies, we construct some strategy objects.
# For details on strategy objects, we refer to [`strategies`](./../Basics/strategies.md).

seq_greedy_EI_strategy = TwoPhaseStrategy(
recommender=SequentialGreedyRecommender(acquisition_function_cls="qEI"),
)
random_strategy = TwoPhaseStrategy(recommender=RandomRecommender())

# We now create one campaign per strategy.
### Constructing campaigns

seq_greedy_EI_campaign = Campaign(
searchspace=searchspace,
strategy=seq_greedy_EI_strategy,
objective=objective,
)
random_campaign = Campaign(
searchspace=searchspace,
strategy=random_strategy,
strategy=RandomRecommender(),
objective=objective,
)

### Performing the simulation loop

# We can now use the `simulate_scenarios` function to simulate a full experiment.
# Note that this function enables to run multiple scenarios by a single function call.
# For this, it is necessary to define a dictionary mapping scenario names to campaigns.
# We use [simulate_scenarios](baybe.simulation.simulate_scenarios) to simulate a full experiment.

scenarios = {
"Sequential greedy EI": seq_greedy_EI_campaign,
Expand All @@ -119,13 +107,24 @@
results = simulate_scenarios(
scenarios,
WRAPPED_FUNCTION,
batch_size=3,
batch_size=BATCH_SIZE,
n_doe_iterations=N_DOE_ITERATIONS,
n_mc_iterations=N_MC_ITERATIONS,
)

# The following lines plot the results and save the plot in run_analytical.png
# We use the plotting utility to create plots.

sns.lineplot(data=results, x="Num_Experiments", y="Target_CumBest", hue="Scenario")
plt.gcf().set_size_inches(24, 8)
plt.savefig("./run_analytical.png")
path = Path(sys.path[0])
ax = sns.lineplot(
data=results,
marker="o",
markersize=10,
x="Num_Experiments",
y="Target_CumBest",
hue="Scenario",
)
create_example_plots(
ax=ax,
path=path,
base_name="botorch_analytical",
)
Loading