-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcake_eating.py
More file actions
267 lines (198 loc) · 8.9 KB
/
cake_eating.py
File metadata and controls
267 lines (198 loc) · 8.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import os
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression
# Weight (beta) applied to next period's value in the Bellman equation.
DISCOUNT_FACTOR = 0.9
# Share of wealth that is NOT carried into the next period's discounting, 1 - beta.
_ONE_MINUS_DISCOUNT = 1 - DISCOUNT_FACTOR
# Intercept A of the closed-form value function V(w) = A + log(w) / (1 - beta).
# Substituting the optimal policy c = (1 - beta) * w into
#     V(w) = log(c) + beta * V(w - c)
# and matching constant terms gives
#     A = [log(1 - beta) + beta * log(beta) / (1 - beta)] / (1 - beta)
VALUE_FUNCTION_CONSTANT_TERM = (
    np.log(_ONE_MINUS_DISCOUNT)
    + np.log(DISCOUNT_FACTOR) * DISCOUNT_FACTOR / _ONE_MINUS_DISCOUNT
) / _ONE_MINUS_DISCOUNT
# Sub-directory (relative to this script's folder) that receives the saved plots.
PLOT_DIR = "plots"
# Candidate fractions of wealth to consume, used by the naive grid searches over
# the action space (action = fraction_consumed * wealth). The endpoints 0 and 1
# are excluded: consuming nothing gives log(0) today, and consuming everything
# leaves zero wealth, whose log is taken when valuing tomorrow.
FRACTION_CONSUMED_GRID_COARSE = np.linspace(start=0.01, stop=0.99, num=4)
FRACTION_CONSUMED_GRID_FINE = np.linspace(start=0.01, stop=0.99, num=8)
def reward_function(action):
    """Return the one-period reward (log utility) of eating ``action`` units of cake.

    The action is the amount consumed in the current period. Because the
    reward is ``log(action)``:

    * eating nothing gives ``-inf`` (the agent "dies"), so an optimal policy
      always eats something today;
    * eating everything gives a large reward today but leaves nothing to eat
      tomorrow, which is just as bad, so the agent also saves.

    Works elementwise on NumPy arrays.
    """
    utility = np.log(action)
    return utility
def optimal_policy(state):
    """Return the analytically optimal consumption for wealth ``state``.

    The state is the agent's wealth (their "amount of cake"). The agent lives
    forever and has no income, so wealth can only shrink by what is eaten.
    Solving the Bellman equation by hand shows that it is optimal to eat a
    constant share ``1 - DISCOUNT_FACTOR`` of current wealth each period. A more
    patient agent (higher discount factor) therefore eats less today and saves
    more for later.
    """
    consumption_share = 1 - DISCOUNT_FACTOR
    return consumption_share * state
def optimal_value_function(state):
    """Return the exact value of following the optimal policy from wealth ``state``.

    The closed form is V(w) = log(w) / (1 - DISCOUNT_FACTOR) + VALUE_FUNCTION_CONSTANT_TERM,
    which is linear in log wealth.
    """
    slope_on_log_wealth = 1 / (1 - DISCOUNT_FACTOR)
    return slope_on_log_wealth * np.log(state) + VALUE_FUNCTION_CONSTANT_TERM
def get_next_state(state, action):
    """Return tomorrow's wealth after consuming ``action`` out of wealth ``state``.

    Wealth is a stock and consumption is a flow: whatever is not eaten today
    is carried over unchanged to tomorrow.
    """
    leftover_wealth = state - action
    return leftover_wealth
def optimal_policy_grid_search(
    state, approximate_value_function, fraction_consumed_grid
):
    """Pick each state's best action by brute-force search over consumption fractions.

    Every candidate action is ``wealth * fraction`` for a fraction taken from
    ``fraction_consumed_grid``. Each candidate is scored by the right-hand side
    of the Bellman equation, ``reward + DISCOUNT_FACTOR * V_hat(next wealth)``,
    where ``V_hat`` is ``approximate_value_function`` evaluated on an (n, 1)
    column of log wealth. The highest-scoring action per state is returned.

    ``state`` must expose ``.size`` (a NumPy array or NumPy scalar).
    """
    # meshgrid's default "xy" indexing: rows index fractions, columns index states.
    wealth_mesh, fraction_mesh = np.meshgrid(state, fraction_consumed_grid)
    candidate_actions = wealth_mesh * fraction_mesh
    remaining_wealth = get_next_state(wealth_mesh, candidate_actions)
    predicted_continuation = approximate_value_function.predict(
        np.log(remaining_wealth.reshape(-1, 1))
    ).reshape(candidate_actions.shape)
    bellman_scores = (
        reward_function(candidate_actions)
        + DISCOUNT_FACTOR * predicted_continuation
    )
    best_fraction_row = np.argmax(bellman_scores, axis=0)
    return candidate_actions[best_fraction_row, np.arange(state.size)]
def optimal_policy_given_approximate_value_function(state, approximate_value_function):
    """Return the consumption that maximizes the Bellman RHS under a log-linear value estimate.

    With an estimated value function ``a + b * log(w)`` (``b`` is
    ``approximate_value_function.coef_[0]``), maximizing
    ``log(c) + DISCOUNT_FACTOR * (a + b * log(state - c))`` over ``c`` gives the
    first-order condition ``c = state / (1 + DISCOUNT_FACTOR * b)``.

    When ``b <= 0`` (as on the first iteration, where future wealth is worth
    nothing), that formula would have the agent eat everything and face a
    ``-inf`` continuation value, so the agent instead eats 99% of wealth.
    """
    slope = approximate_value_function.coef_[0]
    if slope <= 0.0:
        return state * 0.99
    return state / (1 + DISCOUNT_FACTOR * slope)
def get_estimated_values(
    states, approximate_value_function, get_optimal_policy, **kwargs
):
    """Apply one Bellman update at ``states`` using the current value estimate.

    Actions come from ``get_optimal_policy(states, approximate_value_function,
    **kwargs)``. The return value is ``reward(action) + DISCOUNT_FACTOR *
    V_hat(next state)``, where ``V_hat`` is ``approximate_value_function``
    evaluated on log wealth.
    """
    chosen_actions = get_optimal_policy(states, approximate_value_function, **kwargs)
    tomorrow_states = get_next_state(states, chosen_actions)
    # The fitted regression takes log wealth as its single (column) feature.
    tomorrow_values = approximate_value_function.predict(
        np.log(tomorrow_states.reshape(-1, 1))
    )
    # Bellman equation right-hand side.
    return reward_function(chosen_actions) + DISCOUNT_FACTOR * tomorrow_values
def get_coefficients(linear_regression):
    """Stack a fitted regression's intercept on top of its slope coefficients.

    For a single-feature, single-target fit this is a (2, 1) column
    ``[[intercept], [slope]]``, convenient for comparing successive fits.
    """
    intercept = linear_regression.intercept_
    slopes = linear_regression.coef_
    return np.vstack((intercept, slopes))
def calculate_approximate_solution(
    get_optimal_policy,
    max_iterations=10000,
    n_simulations=2000,
    *,
    rng=None,
    **kwargs,
):
    """Estimate a log-linear value function by fitted value iteration.

    Each iteration does three things:

    1. Draw ``n_simulations`` wealth levels uniformly from [0.001, 5.0].
    2. Compute Bellman targets with ``get_estimated_values``, which uses
       ``get_optimal_policy`` to choose actions under the current estimate.
    3. Refit a ``LinearRegression`` of those targets on log wealth.

    Iteration stops when the intercept and slope change by less than
    ``np.allclose(rtol=1e-4, atol=1e-6)`` between fits, or after
    ``max_iterations`` iterations. Either outcome is printed.

    Args:
        get_optimal_policy: Callable
            ``(states, approximate_value_function, **kwargs) -> actions``.
        max_iterations: Maximum number of fitted-value-iteration steps.
        n_simulations: Number of wealth levels sampled per iteration.
        rng: Optional ``numpy.random.Generator`` (or ``RandomState``) used to
            draw wealth levels, for reproducible runs. When ``None``, the
            global ``numpy.random`` state is used.
        **kwargs: Forwarded to ``get_optimal_policy`` (e.g.
            ``fraction_consumed_grid``).

    Returns:
        The fitted ``LinearRegression``. ``coef_[0]`` estimates the slope on
        log wealth, and ``intercept_`` estimates the constant term.
    """
    sampler = np.random if rng is None else rng
    approximate_value_function = LinearRegression()
    # Start from V = 0, i.e. future wealth has no value yet (coef_[0] == 0).
    approximate_value_function.fit(
        X=np.zeros((n_simulations, 1)), y=np.zeros((n_simulations,))
    )
    print(
        f"running solver using {get_optimal_policy} to find actions given estimated value function"
    )
    for i in range(max_iterations):
        states = sampler.uniform(low=0.001, high=5.0, size=n_simulations)
        # The regression's single feature is log wealth.
        X = np.log(states).reshape(-1, 1)
        y = get_estimated_values(
            states, approximate_value_function, get_optimal_policy, **kwargs
        )
        previous_coefficients = get_coefficients(approximate_value_function)
        approximate_value_function.fit(X=X, y=y)
        current_coefficients = get_coefficients(approximate_value_function)
        if np.allclose(
            current_coefficients, previous_coefficients, rtol=1e-04, atol=1e-06
        ):
            print(f"converged at iteration {i}!")
            break
    else:
        # Without this, hitting max_iterations would be indistinguishable
        # from convergence in the printed output.
        print(f"did not converge within {max_iterations} iterations")
    print(
        f"true value is {(1 / (1 - DISCOUNT_FACTOR))}, estimate is {approximate_value_function.coef_}"
    )
    print(
        f"true value is {VALUE_FUNCTION_CONSTANT_TERM}, estimate is {approximate_value_function.intercept_}"
    )
    return approximate_value_function
def save_value_function_plot(
    approximate_value_function,
    approximate_value_function_coarse_grid,
    approximate_value_function_fine_grid,
):
    """Save PNG plots comparing the estimated solutions with the analytical one.

    Two files are written to the ``PLOT_DIR`` sub-directory next to this script.
    The directory is created if it does not already exist.

    * ``cake_eating_problem_value_function.png``: the analytical value function
      and the three estimated value functions.
    * ``cake_eating_problem_action.png``: the analytical optimal policy and the
      actions chosen by grid search under the coarse- and fine-grid estimates.

    Args:
        approximate_value_function: Regression of value on log wealth, fitted
            using the first-order-condition policy.
        approximate_value_function_coarse_grid: Same, fitted using the coarse
            grid search for the action.
        approximate_value_function_fine_grid: Same, fitted using the fine grid
            search for the action.
    """
    outdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), PLOT_DIR)
    # savefig does not create missing directories; without this, a fresh
    # checkout fails with FileNotFoundError.
    os.makedirs(outdir, exist_ok=True)

    # Note: wealth is the state variable.
    wealth = np.linspace(0.01, 100, 1000)
    log_wealth = np.log(wealth).reshape(-1, 1)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.plot(
        wealth,
        optimal_value_function(wealth),
        label="true value function (analytical solution)",
        linewidth=2,
    )
    # Note: don't show the left and right ends of the estimated value functions
    # so that they don't entirely cover/hide the true value function on the plot.
    visible = slice(1, -10)
    for model, linestyle, label in (
        (
            approximate_value_function,
            "--",
            "estimated value function (using log-linear regression & first order condition for action)",
        ),
        (
            approximate_value_function_coarse_grid,
            ":",
            "estimated value function (using log-linear regression & coarse grid search for action)",
        ),
        (
            approximate_value_function_fine_grid,
            ":",
            "estimated value function (using log-linear regression & fine grid search for action)",
        ),
    ):
        ax.plot(
            wealth[visible],
            model.predict(log_wealth)[visible],
            linestyle,
            label=label,
        )
    ax.set_xlabel("wealth (state variable)")
    ax.set_ylabel("value function")
    ax.legend()
    fig.savefig(os.path.join(outdir, "cake_eating_problem_value_function.png"))
    plt.close(fig)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.plot(
        wealth, optimal_policy(wealth), label="optimal action (analytical solution)"
    )
    for model, grid, label in (
        (
            approximate_value_function_coarse_grid,
            FRACTION_CONSUMED_GRID_COARSE,
            "action (using log-linear regression & coarse grid search for action)",
        ),
        (
            approximate_value_function_fine_grid,
            FRACTION_CONSUMED_GRID_FINE,
            "action (using log-linear regression & fine grid search for action)",
        ),
    ):
        ax.plot(wealth, optimal_policy_grid_search(wealth, model, grid), label=label)
    ax.set_xlabel("wealth (state variable)")
    ax.set_ylabel("amount consumed (action)")
    ax.legend()
    fig.savefig(os.path.join(outdir, "cake_eating_problem_action.png"))
    # The original never closed this second figure, leaking it for the
    # lifetime of the process.
    plt.close(fig)
def main():
    """Fit all three approximate value functions and save the comparison plots.

    The first fit chooses actions from the first-order condition. The second
    and third fits use grid search over the coarse and fine consumption-fraction
    grids, in that order.
    """
    first_order_condition_fit = calculate_approximate_solution(
        optimal_policy_given_approximate_value_function
    )
    coarse_fit, fine_fit = [
        calculate_approximate_solution(
            optimal_policy_grid_search, fraction_consumed_grid=grid
        )
        for grid in (FRACTION_CONSUMED_GRID_COARSE, FRACTION_CONSUMED_GRID_FINE)
    ]
    save_value_function_plot(first_order_condition_fit, coarse_fit, fine_fit)


if __name__ == "__main__":
    main()