-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample_usage.py
More file actions
149 lines (125 loc) · 4.34 KB
/
example_usage.py
File metadata and controls
149 lines (125 loc) · 4.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/env python3
"""
Generate minimal example figures for MARL Games documentation.
"""
from __future__ import annotations
import matplotlib.pyplot as plt
import numpy as np
import os
from marl_games import MatrixGame, DiGrid, QL
def setup():
    """Prepare the output directory and global matplotlib styling.

    Creates the ``figures`` directory if it does not already exist, selects a
    seaborn "whitegrid" style sheet, and forces white figure/axes backgrounds.

    Raises:
        OSError: if neither spelling of the whitegrid style is known to the
            installed Matplotlib.
    """
    os.makedirs('figures', exist_ok=True)

    # Matplotlib 3.6 renamed its bundled seaborn style sheets to
    # 'seaborn-v0_8-*'; earlier releases only know the unprefixed name.
    # Try the current name first and fall back so the script runs on both.
    try:
        plt.style.use('seaborn-v0_8-whitegrid')
    except OSError:
        plt.style.use('seaborn-whitegrid')

    plt.rcParams.update({'figure.facecolor': 'white', 'axes.facecolor': 'white'})
def fig1_summary():
    """Render the epsilon-greedy summary figure to ``figures/1_summary.png``.

    Repeats epsilon-greedy Q-learning on a fixed payoff array, collects every
    run's log in a ``QL.QLogList``, reduces them with ``median()``, and passes
    the result to ``QL.QPlot.summary`` before saving the current figure.
    """
    print("Generating Summary Plot...")

    # Experiment configuration.
    n_runs = 100
    n_steps = 400
    game_payoff = np.array(
        [
            [[2, 3], [4, 1]],
            [[3, 1], [2, 4]],
        ],
        dtype=np.float64,
    )
    q_start = np.array([[0, 1], [2, 3]], np.float64)

    runs = QL.QLogList()
    for _ in range(n_runs):
        # Hand each run its own copy of the starting Q-values so the shared
        # template array is never modified by a run.
        run_log = QL.epsilon_greedy_q_learning(
            payoff=game_payoff,
            num_iterations=n_steps,
            init_q_values=q_start.copy(),
            alpha=0.1,
            epsilon=0.1,
        )
        runs.append(run_log)

    QL.QPlot.summary(runs.median())

    plt.tight_layout()
    plt.savefig('figures/1_summary.png', dpi=150, bbox_inches='tight')
    plt.close()
def fig2_field():
    """Render the Stag Hunt replicator-dynamics field to ``figures/2_field.png``.

    Asks ``DiGrid.replicator_dynamics`` for a 12x12 grid and draws the
    returned arrays as a quiver plot on the unit square.
    """
    print("Generating Vector Field Plot...")

    grid_x, grid_y, flow_x, flow_y = DiGrid.replicator_dynamics(
        MatrixGame.STAG_HUNT, grid_shape=(12, 12)
    )

    _, axes = plt.subplots(figsize=(6, 6))
    axes.quiver(grid_x, grid_y, flow_x, flow_y, color='steelblue')

    # Clamp both axes to [0, 1] and keep the plot square.
    axes.set_xlim(0, 1)
    axes.set_ylim(0, 1)
    axes.set_title("Replicator Dynamics (Stag Hunt)")
    axes.set_aspect('equal')

    plt.tight_layout()
    plt.savefig('figures/2_field.png', dpi=150, bbox_inches='tight')
    plt.close()
def fig3_trace():
    """Render the learning-trace comparison to ``figures/3_trace.png``.

    Draws two traces on one set of axes for the Stag Hunt game: a single
    Boltzmann Q-learning run started from ``generate_policy(0.1, 0.9)``, and
    the mean of 32 runs started from ``generate_policy(0.9, 0.1)``.
    """
    print("Generating Trace Plot...")

    game = MatrixGame.STAG_HUNT
    _, axes = plt.subplots(figsize=(6, 6))

    # Learner settings shared by every run in this figure.
    learner_kwargs = {
        'payoff': game,
        'num_iterations': 1000,
        'alpha': 0.001,
        'temperature': 0.1,
    }

    # Trace 1: one individual run, drawn with QPlot.trace's default styling.
    single_start = MatrixGame.generate_policy(0.1, 0.9)
    single_q = QL.generate_mean_q_values(game, 0.1, single_start)
    single_log = QL.boltzmann_q_learning(init_q_values=single_q, **learner_kwargs)
    QL.QPlot.trace(axes, single_log)

    # Trace 2: mean over 32 runs that all begin from the same Q-values.
    batch_start = MatrixGame.generate_policy(0.9, 0.1)
    batch_q = QL.generate_mean_q_values(game, 0.1, batch_start)
    batch_logs = QL.QLogList()
    for _ in range(32):
        # Each run gets a private copy of the shared starting Q-values.
        batch_logs.append(
            QL.boltzmann_q_learning(init_q_values=batch_q.copy(), **learner_kwargs)
        )
    QL.QPlot.trace(axes, batch_logs.mean())

    axes.set_xlim(0, 1)
    axes.set_ylim(0, 1)
    axes.set_title("Trace Comparison: Speed vs Smoothness")
    axes.set_aspect('equal')
    axes.set_xlabel('P1 (Prob Action 0)')
    axes.set_ylabel('P2 (Prob Action 0)')

    plt.tight_layout()
    plt.savefig('figures/3_trace.png', dpi=150, bbox_inches='tight')
    plt.close()
def fig4_combined():
    """Render the field-plus-traces figure to ``figures/4_combined.png``.

    Draws the Boltzmann replicator-dynamics field for Stag Hunt as a faint
    background, then overlays one Boltzmann Q-learning trace per starting
    policy yielded by ``MatrixGame.generate_policy_grid((3, 3))``.
    """
    print("Generating Combined Plot...")

    game = MatrixGame.STAG_HUNT
    temperature = 0.1  # used for both the background field and the learner

    # Background layer: Boltzmann dynamics, so the field matches the learner.
    grid_x, grid_y, flow_x, flow_y = DiGrid.boltzmann_replicator_dynamics(
        game, grid_shape=(12, 12), temperature=temperature
    )
    _, axes = plt.subplots(figsize=(6, 6))
    axes.quiver(grid_x, grid_y, flow_x, flow_y, color='lightgray', alpha=0.5)

    # Foreground layer: one trace per grid policy
    # (a (3, 3) grid is presumably 9 starting points, per the plot title).
    for start_policy in MatrixGame.generate_policy_grid((3, 3)):
        start_q = QL.generate_mean_q_values(game, 0.1, start_policy)
        trace_log = QL.boltzmann_q_learning(
            payoff=game,
            num_iterations=1000,
            alpha=0.001,
            temperature=temperature,
            init_q_values=start_q,
        )
        QL.QPlot.trace(axes, trace_log)

    axes.set_xlim(0, 1)
    axes.set_ylim(0, 1)
    axes.set_title("Dynamics + 9 Learning Traces")
    axes.set_aspect('equal')

    plt.tight_layout()
    plt.savefig('figures/4_combined.png', dpi=150, bbox_inches='tight')
    plt.close()
if __name__ == '__main__':
    # Configure output once, then build each figure in its numbered order.
    setup()
    for build_figure in (fig1_summary, fig2_field, fig3_trace, fig4_combined):
        build_figure()
    print("Done.")