-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmobo_assignment_test.py
More file actions
153 lines (116 loc) · 4.94 KB
/
mobo_assignment_test.py
File metadata and controls
153 lines (116 loc) · 4.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
import warnings
from utils import set_seeds, measure_epoxy_old
import numpy as np
from ax.service.ax_client import AxClient, ObjectiveProperties
import matplotlib.pyplot as plt
from ax.modelbridge.cross_validation import cross_validate, compute_diagnostics
from ax.core.observation import ObservationFeatures
import pytest
@pytest.fixture(scope="session")
def get_namespace():
    """Execute the assignment script once and return its global namespace.

    The script ``mobo_assignment.py`` is opened with a path relative to the
    current working directory, executed in a fresh dict, and that dict is
    handed to every test in the session.

    Returns:
        dict: The namespace populated by executing the assignment script.

    Skips (via ``pytest.skip``) when any of the variables the tests read is
    missing from the executed namespace.
    """
    script_fname = "mobo_assignment.py"

    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(script_fname, encoding="utf-8") as script_file:
        script_content = script_file.read()

    namespace = {}
    # NOTE: exec runs the assignment with the privileges of the test process;
    # only point this at a script you trust. Compiling with the real filename
    # makes tracebacks point into the assignment script rather than "<string>".
    exec(compile(script_content, script_fname, "exec"), namespace)

    required_vars = [
        "ax_client",
        "ax_client_full",
        "ax_client_thresh",
        "num_pareto_optimal_thresh",
        "num_pareto_sustainable_thresh",
        "num_pareto_optimal",
        "num_pareto_sustainable",
        "max_strength",
        "max_glass_t",
        "tradeoff",
    ]
    missing_vars = [var for var in required_vars if var not in namespace]
    if missing_vars:
        pytest.skip(f"Assignment incomplete. Missing variables: {missing_vars}")
    return namespace
def test_task_a(get_namespace):
    """Task A: check the search space and trial count of ``ax_client_full``.

    Asserts that the experiment has five parameters, that every parameter
    name is one of EA, EB, EC, AA, AB, and that the trials data frame has
    40 rows.

    Args:
        get_namespace: Namespace dict produced by executing the assignment
            script; must contain ``ax_client_full``.
    """
    expected_names = ["EA", "EB", "EC", "AA", "AB"]
    running_ax_client = get_namespace["ax_client_full"]
    user_op_params = running_ax_client.experiment.parameters

    # The search space must hold five parameters.
    assert len(user_op_params) == 5, "Expected 5 parameters, got {}".format(
        len(user_op_params)
    )

    # Every parameter name must be one of the expected names.
    assert all(
        param in expected_names for param in user_op_params
    ), "Expected parameters named {}, got {}".format(
        expected_names, list(user_op_params)
    )

    # The optimization budget is 40 trials; build the data frame once and
    # reuse the count in both the check and the failure message.
    num_trials = len(running_ax_client.get_trials_data_frame())
    assert (
        num_trials == 40
    ), "Expected optimization budget of 40 trials, got {}".format(num_trials)
def test_task_b(get_namespace):
    """Task B: check the reported maxima against their lower bounds.

    Reads ``max_strength`` and ``max_glass_t`` from the assignment namespace
    and asserts they are strictly greater than 125 and 110 respectively.

    Args:
        get_namespace: Namespace dict produced by executing the assignment
            script.
    """
    # Both values are looked up before either assertion runs.
    strength, glass_t = (
        get_namespace[key] for key in ("max_strength", "max_glass_t")
    )

    # max_strength must exceed 125.
    assert strength > 125, "Expected max_strength > 125, got {}".format(strength)

    # max_glass_t must exceed 110.
    assert glass_t > 110, "Expected max_glass_t > 110, got {}".format(glass_t)
def test_task_c(get_namespace):
    """Task C: check the reported Pareto counts.

    Reads ``num_pareto_optimal`` and ``num_pareto_sustainable`` from the
    assignment namespace and asserts they equal 15 and 6 respectively.

    Args:
        get_namespace: Namespace dict produced by executing the assignment
            script.
    """
    # Both values are looked up before either assertion runs.
    optimal_count, sustainable_count = (
        get_namespace[key]
        for key in ("num_pareto_optimal", "num_pareto_sustainable")
    )

    # num_pareto_optimal must be exactly 15.
    assert (
        optimal_count == 15
    ), "Expected num_pareto_optimal: 15, got {}".format(optimal_count)

    # num_pareto_sustainable must be exactly 6.
    assert (
        sustainable_count == 6
    ), "Expected num_pareto_sustainable: 6, got {}".format(sustainable_count)
def test_task_d(get_namespace):
    """Task D: check the search space, trial count and objective thresholds
    of ``ax_client_thresh``.

    Asserts that the experiment has five parameters, that every parameter
    name is one of EA, EB, EC, AA, AB, that the trials data frame has 40
    rows, and that the first two objective-threshold bounds are each 75 or 85.

    Args:
        get_namespace: Namespace dict produced by executing the assignment
            script; must contain ``ax_client_thresh``.
    """
    expected_names = ["EA", "EB", "EC", "AA", "AB"]
    allowed_bounds = [75, 85]
    running_ax_client = get_namespace["ax_client_thresh"]
    user_op_params = running_ax_client.experiment.parameters

    # The search space must hold five parameters.
    assert len(user_op_params) == 5, "Expected 5 parameters, got {}".format(
        len(user_op_params)
    )

    # Every parameter name must be one of the expected names.
    assert all(
        param in expected_names for param in user_op_params
    ), "Expected parameters named {}, got {}".format(
        expected_names, list(user_op_params)
    )

    # The optimization budget is 40 trials; build the data frame once and
    # reuse the count in both the check and the failure message.
    num_trials = len(running_ax_client.get_trials_data_frame())
    assert (
        num_trials == 40
    ), "Expected optimization budget of 40 trials, got {}".format(num_trials)

    thresholds = (
        running_ax_client.experiment.optimization_config.objective_thresholds
    )
    # Fail with a readable message instead of an IndexError when fewer than
    # two thresholds were configured.
    assert (
        len(thresholds) >= 2
    ), "Expected at least 2 objective thresholds, got {}".format(len(thresholds))

    # Each of the first two threshold bounds must be 75 or 85.
    # NOTE(review): the bounds are checked independently, so two identical
    # bounds (e.g. both 75) also pass — confirm whether one of each is required.
    for threshold in thresholds[:2]:
        assert (
            threshold.bound in allowed_bounds
        ), "Expected objective threshold: 75 or 85, got {}".format(threshold.bound)
def test_task_e(get_namespace):
    """Task E: check the Pareto counts obtained with objective thresholds.

    Reads ``num_pareto_optimal_thresh`` and ``num_pareto_sustainable_thresh``
    from the assignment namespace and asserts each is at least 12.

    Args:
        get_namespace: Namespace dict produced by executing the assignment
            script.
    """
    user_pareto_optimal_thresh = get_namespace["num_pareto_optimal_thresh"]
    user_pareto_sustainable_thresh = get_namespace["num_pareto_sustainable_thresh"]

    # num_pareto_optimal_thresh must be at least 12.
    assert (
        user_pareto_optimal_thresh >= 12
    ), "Expected num_pareto_optimal_thresh >= 12, got {}".format(
        user_pareto_optimal_thresh
    )

    # num_pareto_sustainable_thresh must be at least 12.
    # NOTE(review): both counts share the same lower bound of 12 — confirm
    # this is the intended bound for the sustainable count.
    assert (
        user_pareto_sustainable_thresh >= 12
    ), "Expected num_pareto_sustainable_thresh >= 12, got {}".format(
        user_pareto_sustainable_thresh
    )
def test_task_f(get_namespace):
    """Task F: check the reported tradeoff value.

    Reads ``tradeoff`` from the assignment namespace and asserts it is
    strictly less than -0.5.

    Args:
        get_namespace: Namespace dict produced by executing the assignment
            script.
    """
    tradeoff_value = get_namespace["tradeoff"]

    # tradeoff must be below -0.5.
    assert (
        tradeoff_value < -0.5
    ), "Expected tradeoff < -0.5, got {}".format(tradeoff_value)