-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_bandit_classification.py
More file actions
457 lines (354 loc) · 18.6 KB
/
run_bandit_classification.py
File metadata and controls
457 lines (354 loc) · 18.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
import os
import re
import argparse
import pandas as pd
import numpy as np
from typing import Optional
from tqdm import tqdm
from dataclasses import dataclass
from src.bandit import get_classification_bandit, ClassificationBandit, ButtonsBandit
from src.bandit_algorithms import UCB1_Algorithm
from src.bayesian_optimisation import new_candidate
from src.utils import BanditClassificationUtils, calculate_entropy, calculate_kl_divergence, calculate_discrete_mean, calculate_discrete_variance, calculate_min_Va_by_KL_rank
from src.prompt import BanditClassificationPrompt
from src.chat import chat
pd.set_option('display.max_columns', None)  # show every DataFrame column when printing results
# CLI definition; parsed immediately below and fed into BanditClassificationExperimentConfig via **vars(args).
parser = argparse.ArgumentParser(description='Running Toy Classification')
"""LLM API Configuration"""
# Endpoint and sampling settings for the LLM that estimates label probabilities.
parser.add_argument("--model_name", default="Qwen/Qwen2.5-14B", type=str)
parser.add_argument("--model_port", default="8000", type=str)
parser.add_argument("--model_ip", default="localhost", type=str)
parser.add_argument("--model_temperature", default=1, type=float)
parser.add_argument("--is_local_client", default=1, type=int)  # 0/1 flag: talk to a locally hosted server
"""Bandit Configuration"""
# Bandit environment: arm count, reward geometry (midpoint/gap), and RNG seed.
parser.add_argument("--bandit_name", default="buttons", type=str)
parser.add_argument("--bandit_num_arms", default=5, type=int)
parser.add_argument("--bandit_midpoint", default=0.5, type=float)
parser.add_argument("--bandit_gap", default=0.2, type=float)
parser.add_argument("--bandit_seed", default=0, type=int)
parser.add_argument("--bandit_exploration_rate", default=2, type=float)  # UCB exploration coefficient
parser.add_argument("--is_contextual_bandit", default=0, type=int)  # 0/1 flag
"""Experiment Configuration"""
# num_random_trials initial trials act uniformly at random before UCB selection kicks in.
parser.add_argument("--num_trials", default=10, type=int)
parser.add_argument("--num_random_trials", default=3, type=int)
parser.add_argument("--uncertainty_type", default="epistemic", type=str)  # "epistemic" | "total" | "ucb1"
"""Permutation Related Configuration"""
# Number of in-context permutations averaged per probability estimate.
parser.add_argument("--num_permutations", default=10, type=int)
parser.add_argument("--permute_context", default=1, type=int)  # 0/1 flag: permute ICL examples per call
"""Seed Configuration"""
parser.add_argument("--numpy_seed", default=0, type=int)
parser.add_argument("--fixed_permutation_seed", default=0, type=int)  # used when --permute_context=0
"""Z Configuration"""
# Hypothetical query points z: count, feature-perturbation scale, rounding, and KL-rank cutoff.
parser.add_argument("--num_z", default=1, type=int)
parser.add_argument("--perturbation_std", default=1.0, type=float)
parser.add_argument("--decimal_places", default=3, type=int)
parser.add_argument("--min_KL_rank", default=1, type=int)
"""Save Configuration"""
parser.add_argument("--run_name", default="test")
parser.add_argument("--save_directory", default="other")
parser.add_argument("--num_api_calls_save_value", default=0, type=int)  # starting offset for the API-call counter
parser.add_argument("--verbose_output", default=0, type=int)  # 0/1 flag
args = parser.parse_args()
@dataclass
class BanditClassificationExperimentConfig:
    """Typed container for every CLI flag of the experiment.

    Built directly from argparse via ``**vars(args)``, so field names must
    match the argument names exactly. Integer fields holding 0/1 are used
    as boolean flags throughout the experiment code.
    """
    # --- LLM API configuration ---
    model_name: str
    model_port: str
    model_ip: str
    model_temperature: float
    is_local_client: int  # 0/1 flag
    # --- Bandit configuration ---
    bandit_name: str
    bandit_num_arms: int
    bandit_midpoint: float
    bandit_gap: float
    bandit_seed: int
    bandit_exploration_rate: float
    is_contextual_bandit: int  # 0/1 flag
    # --- Seeds ---
    numpy_seed: int
    fixed_permutation_seed: int
    # --- Experiment configuration ---
    num_trials: int
    num_random_trials: int
    uncertainty_type: str  # "epistemic", "total", or "ucb1"
    # --- Permutation configuration ---
    num_permutations: int
    permute_context: int  # 0/1 flag
    # --- Z (hypothetical query) configuration ---
    num_z: int
    perturbation_std: float
    decimal_places: int
    min_KL_rank: int
    # --- Save configuration ---
    run_name: str  # fixed: was mistakenly annotated int; argparse supplies a str default ("test")
    save_directory: str  # fixed: was mistakenly annotated int; argparse supplies a str default ("other")
    num_api_calls_save_value: int
    verbose_output: int  # 0/1 flag
class BanditClassificationExperiment:
    """LLM-in-the-loop bandit experiment driven by classification prompts.

    Each trial either explores uniformly at random (the first
    ``num_random_trials`` warm-up trials) or scores every action with an
    LLM-estimated expected reward plus an uncertainty bonus (epistemic /
    total variance, or classic UCB1) and plays the highest-scoring action.
    Observations accumulate in ``self.D_rows`` and condition all subsequent
    LLM prompts in-context.
    """

    def __init__(self, config: BanditClassificationExperimentConfig):
        self.config = config
        self.rng = np.random.default_rng(self.config.numpy_seed)
        self.prompter = BanditClassificationPrompt()
        self.create_bandit()
        # This counter also serves as the permutation seed, so every API
        # call sees a fresh in-context ordering (see calculate_avg_probs).
        self.num_api_calls = self.config.num_api_calls_save_value
        if self.config.uncertainty_type == "ucb1":
            self.UCB1_algorithm = UCB1_Algorithm(num_arms=self.config.bandit_num_arms, c=self.config.bandit_exploration_rate)
        else:
            self.UCB1_algorithm = None

    def create_bandit(self):
        """Instantiate the bandit environment and cache its action/label spaces."""
        self.bandit: ClassificationBandit = get_classification_bandit(
            bandit_name=self.config.bandit_name,
            num_arms=self.config.bandit_num_arms,
            gap=self.config.bandit_gap,
            midpoint=self.config.bandit_midpoint,
            seed=self.config.bandit_seed,
        )
        self.label_keys = self.bandit.get_reward_space()
        self.action_space = self.bandit.get_action_space()
        if isinstance(self.bandit, ButtonsBandit):
            print(f"Best arm: {self.bandit.best_arm}")
        if self.config.is_contextual_bandit:
            self.feature_columns = self.bandit.get_context_feature_cols()
            print("Features:", self.feature_columns)
        else:
            self.feature_columns = []
        self.num_trials = self.config.num_trials
        # Observation history, one row per trial; created in single_trial.
        self.D_rows: Optional[pd.DataFrame] = None

    @property
    def D_feature_stds(self):
        """Per-feature standard deviations of the observations collected so far."""
        self._D_feature_stds = self.D_rows[self.feature_columns].std().to_numpy().flatten()
        return self._D_feature_stds

    @property
    def D_note_label_df(self):
        """The (note, label) view of the history used as in-context examples."""
        return self.D_rows[['note', 'label']]

    def save_D_rows(self):
        """Persist the observation history to results/bandits/<bandit>/<dir>/D_<run>.csv."""
        out_dir = f"results/bandits/{self.config.bandit_name}/{self.config.save_directory}"
        # exist_ok avoids the check-then-create race of the original exists()/makedirs pair.
        os.makedirs(out_dir, exist_ok=True)
        self.D_rows.to_csv(f"{out_dir}/D_{self.config.run_name}.csv", index=False)

    def calculate_avg_probs(
        self,
        query_note: str,
        probability_calculated: str,
        icl_z_note: Optional[str] = None,
        icl_u_label: Optional[str | int] = None,
    ):
        """Average the LLM's label distribution over permuted in-context prompts.

        Args:
            query_note: The textual query (x or z) whose label distribution is estimated.
            probability_calculated: Human-readable tag (e.g. "p(y|x,D)") used in logs.
            icl_z_note: Optional extra in-context example note (the hypothetical z).
            icl_u_label: Optional label assumed for ``icl_z_note``.

        Returns:
            dict mapping each label to its probability averaged over the
            permutation seeds whose API call succeeded.

        Raises:
            RuntimeError: If every permutation seed fails (the original code
                would hit an opaque ZeroDivisionError here).
        """
        # Initialize p(y|x)
        avg_probs = {label: 0.0 for label in self.label_keys}
        successful_seeds = 0
        for seed in range(self.config.num_permutations):
            if self.config.verbose_output:
                print(f"\n{probability_calculated} Seed {seed + 1}/{self.config.num_permutations}")
            # The running API-call count doubles as a unique permutation seed.
            permutation_seed = self.num_api_calls
            try:
                prompt = self.prompter.get_general_prompt(
                    D_df=self.D_note_label_df,
                    query_note=query_note,
                    permutation_seed=permutation_seed if self.config.permute_context else self.config.fixed_permutation_seed,
                    icl_z_note=icl_z_note,
                    icl_u_label=icl_u_label,
                )
                if self.config.verbose_output:
                    print(f"Prompt for {probability_calculated}:")
                    print(prompt)
                # Get the prediction and probabilities from the model
                pred, probs = chat(
                    prompt,
                    self.label_keys,
                    seed=permutation_seed,
                    model=self.config.model_name,
                    port=self.config.model_port,
                    ip=self.config.model_ip,
                    temperature=self.config.model_temperature,
                    is_local_client=self.config.is_local_client,
                )
                self.num_api_calls += 1
                # Accumulate probabilities
                for label, prob in probs.items():
                    avg_probs[label] += prob
                successful_seeds += 1
            # Was a bare `except:` (also swallowed KeyboardInterrupt/SystemExit);
            # keep the best-effort semantics but only for ordinary exceptions.
            except Exception as exc:
                print(f"Seed {seed + 1} failed.")
                if self.config.verbose_output:
                    print(f"Reason: {exc}")
        if successful_seeds == 0:
            raise RuntimeError(
                f"All {self.config.num_permutations} permutation seeds failed while computing {probability_calculated}."
            )
        avg_probs = {label: prob / successful_seeds for label, prob in avg_probs.items()}
        if self.config.verbose_output:
            print(f"\nAveraged {probability_calculated} probabilities: {avg_probs}")
        return avg_probs

    def get_random_action(self):
        """Sample an action uniformly at random from the action space."""
        action = self.rng.choice(self.action_space)
        return action

    def get_next_z(self, z_idx: int, context: Optional[dict] = None, action: Optional[str | int] = None):
        """Generate the z_idx-th hypothetical query point and store it in ``self.z_data``.

        For contextual bandits, the context features are perturbed with
        Gaussian noise scaled by ``perturbation_std`` times the empirical
        feature stds of D, retrying (up to 100 times) until the rounded
        perturbation is novel within this trial.
        """
        if action is None:
            action = self.get_random_action()
        new_value = None
        # Only the contextual case has features to perturb; the original ran
        # this loop as a 100-iteration no-op when not contextual.
        if self.config.is_contextual_bandit:
            for _ in range(100):
                new_value = self.rng.normal(
                    np.array([float(x) for x in list(context.values())]),
                    self.config.perturbation_std * self.D_feature_stds,
                    len(self.feature_columns)
                )
                new_value = np.round(new_value, self.config.decimal_places)
                if not any(np.array_equal(new_value, previous_z_value) for previous_z_value in self.previous_z_values):
                    self.previous_z_values.append(new_value)
                    break
        if z_idx == 0:
            dict_data = {}
            if self.config.is_contextual_bandit:
                dict_data = {feature_column: new_value[i] for i, feature_column in enumerate(self.feature_columns)}
            dict_data.update({"action": action})
            self.z_data = pd.DataFrame([dict_data])
            if self.config.is_contextual_bandit:
                self.z_data["note"] = self.z_data.apply(lambda row: BanditClassificationUtils.parse_features_and_action_to_note(row=row, feature_columns=self.feature_columns, action=action, decimal_places=self.config.decimal_places), axis=1)
            else:
                self.z_data["note"] = BanditClassificationUtils.parse_features_and_action_to_note(action=action, decimal_places=self.config.decimal_places)
        else:
            modified_row = self.z_data.loc[z_idx - 1].copy()
            # Guard: in the non-contextual case there is no perturbed feature
            # vector (the original unconditionally assigned the undefined
            # value, raising NameError for num_z > 1 without contexts).
            if self.config.is_contextual_bandit:
                modified_row[self.feature_columns] = new_value
            modified_row["action"] = action
            modified_row["note"] = BanditClassificationUtils.parse_features_and_action_to_note(action=action, row=modified_row, feature_columns=self.feature_columns, decimal_places=self.config.decimal_places)
            self.z_data.loc[z_idx] = modified_row

    def process_single_trial_action(self, trial: int, action: str | int, context: Optional[pd.Series] = None):
        """Score one candidate action by decomposing its predictive uncertainty.

        Computes p(y|x,D); then for each of ``num_z`` hypothetical points z,
        computes p(u|z,D) and p(y|x,u,z,D), marginalises over u to get
        p(y|x,z,D), and derives the aleatoric/epistemic variance split plus
        KL divergences between p(y|x,D) and p(y|x,z,D).

        Returns:
            DataFrame with one row per z, post-processed by
            ``calculate_min_Va_by_KL_rank``.
        """
        self.previous_z_values = []
        x = BanditClassificationUtils.parse_features_and_action_to_note(action=action, row=context, feature_columns=self.feature_columns, decimal_places=self.config.decimal_places)
        # Compute p(y|x,D)
        avg_pyx_probs = self.calculate_avg_probs(x, "p(y|x,D)")
        total_variance = calculate_discrete_variance(avg_pyx_probs)
        mean_y = calculate_discrete_mean(avg_pyx_probs)
        save_dict_list = []
        for i in range(self.config.num_z):
            self.get_next_z(z_idx=i, context=context, action=action)
            row = self.z_data.iloc[i]
            z = row['note']
            if self.config.verbose_output:
                print(f"z: {z}")
            # Compute p(u|z,D)
            avg_puz_probs = self.calculate_avg_probs(z, "p(u|z,D)")
            # Compute p(y|x,u,z,D) for every hypothetical label u
            avg_pyxu_z_probs = {}
            for outer_label in self.label_keys:
                probability_calculated = f"p(y|x,u={outer_label},z,D)"
                avg_probs_for_outer_label = self.calculate_avg_probs(
                    query_note=x,
                    probability_calculated=probability_calculated,
                    icl_z_note=z,
                    icl_u_label=outer_label
                )
                avg_pyxu_z_probs.update({probability_calculated: avg_probs_for_outer_label})
            # Marginalisation: p(y|x,z,D) = sum_u p(y|x,u,z,D) * p(u|z,D)
            avg_pyxz_probs = {}
            for label in self.label_keys:  # Iterate over all possible values of y
                avg_pyxz_probs[label] = sum(
                    avg_pyxu_z_probs[f"p(y|x,u={u_label},z,D)"][label] * avg_puz_probs[u_label]
                    for u_label in self.label_keys
                )
            # Variance decomposition: Va = E_u[Var(y|x,u,z,D)], Ve = total - Va
            Var_uz = calculate_discrete_variance(avg_puz_probs)
            Var_yxuz = {f"Var[{key}]": calculate_discrete_variance(value) for key, value in avg_pyxu_z_probs.items()}
            E_Var_yxuz = 0.0
            for label in self.label_keys:
                E_Var_yxuz += Var_yxuz[f"Var[p(y|x,u={label},z,D)]"] * avg_puz_probs[label]
            Va_variance = np.round(E_Var_yxuz, 5)
            Ve_variance = total_variance - Va_variance
            # KL Divergence (both directions between p(y|x,D) and p(y|x,z,D))
            kl_pyx_pyxz = calculate_kl_divergence(avg_pyx_probs, avg_pyxz_probs)
            kl_pyxz_pyx = calculate_kl_divergence(avg_pyxz_probs, avg_pyx_probs)
            # Assemble one results row for this z
            save_dict = {f"z_{feature}": row[feature] for feature in self.feature_columns}
            save_dict["z_note"] = z
            save_dict_x = {f"x_{feature}": context[feature] for feature in self.feature_columns}
            save_dict_x["x_note"] = x
            save_dict = {**save_dict, **save_dict_x}
            for label, prob in avg_pyx_probs.items():
                save_dict[f"p(y={label}|x,D)"] = prob
            for label, prob in avg_puz_probs.items():
                save_dict[f"p(u={label}|z,D)"] = prob
            for key, outer_label_probs in avg_pyxu_z_probs.items():
                for label, prob in outer_label_probs.items():
                    # Rewrite "p(y|x,u=..,z,D)" -> "p(y=<label>|x,u=..,z,D)"
                    new_key = re.sub(r'y', f'y={label}', key, count=1)
                    save_dict[new_key] = prob
            for label, prob in avg_pyxz_probs.items():
                save_dict[f"p(y={label}|x,z,D)"] = prob
            save_dict["Var[u|z,D]"] = Var_uz
            for key, variance in Var_yxuz.items():
                save_dict[key] = variance
            save_dict["Var[y|x,D]"] = total_variance
            save_dict["Va_variance"] = Va_variance
            save_dict["Ve_variance"] = Ve_variance
            save_dict["E[y|x,D]"] = mean_y
            save_dict["kl_pyx_pyxz"] = kl_pyx_pyxz
            save_dict["kl_pyxz_pyx"] = kl_pyxz_pyx
            save_dict["api_calls"] = self.num_api_calls
            save_dict_list.append(save_dict)
        save_df = pd.DataFrame(save_dict_list)
        save_df = calculate_min_Va_by_KL_rank(save_df, self.config.min_KL_rank, upper_bound_by_total_U=True, uncertainty_type="variance")
        return save_df

    def get_single_trial_action(self, trial: int, context: Optional[pd.Series] = None, random_action: bool = False, uncertainty_type: str = "epistemic"):
        """Pick the action for this trial.

        Random during warm-up; otherwise UCB over LLM-estimated Q values with
        the exploration bonus selected by ``uncertainty_type`` ("epistemic",
        "total", or "ucb1"). Ties at the maximum UCB value are broken
        uniformly at random.
        """
        if random_action:
            action_taken = self.get_random_action()
        else:
            Q_values = {}
            U_values = {}
            UCB_values = {}
            for action in tqdm(self.action_space):
                save_df = self.process_single_trial_action(trial=trial, action=action, context=context)
                Q_values.update({action: save_df["E[y|x,D]"].values[0]})
                if uncertainty_type == "epistemic":
                    U_values.update({action: np.sqrt(save_df["max_Ve_variance"].values[0])})
                elif uncertainty_type == "total":
                    U_values.update({action: np.sqrt(save_df["Var[y|x,D]"].values[0])})
                elif uncertainty_type == "ucb1":
                    UCB_uncertainty = self.UCB1_algorithm.get_uncertainty(action)
                    U_values.update({action: UCB_uncertainty})
                UCB_values.update({action: Q_values[action] + self.config.bandit_exploration_rate * U_values[action]})
            print(f"Q values: {Q_values}")
            print(f"U values: {U_values}")
            print(f"UCB values: {UCB_values}")
            # Stored so single_trial can log per-action diagnostics into D_rows.
            self.Q_values = Q_values
            self.U_values = U_values
            self.UCB_values = UCB_values
            max_UCB_value = max(UCB_values.values())
            max_UCB_action = [action for action, value in UCB_values.items() if value == max_UCB_value]
            action_taken = self.rng.choice(max_UCB_action)
        return action_taken

    def single_trial(self, trial: int):
        """Run one trial: choose an action, observe reward, append to D_rows, save."""
        print(f"\nTrial {trial + 1}/{self.config.num_trials}")
        if self.config.is_contextual_bandit:
            context = self.bandit.get_next_context()
        else:
            context = None
        is_random_action = trial < self.config.num_random_trials
        action_taken = self.get_single_trial_action(trial, context, is_random_action, self.config.uncertainty_type)
        reward = self.bandit.get_reward(action_taken)
        regret = self.bandit.get_optimal_mean_reward() - reward
        print(f"Action: {action_taken}; Reward: {reward}; Regret: {regret}")
        if self.config.uncertainty_type == "ucb1":
            self.UCB1_algorithm.update(action_taken, reward)
        # NOTE(review): a Series/DataFrame context is used directly as the trial
        # row — presumably it already carries a "note" column; verify against
        # the bandit's get_next_context implementation.
        if isinstance(context, pd.Series | pd.DataFrame):
            trial_df = context
        else:
            if self.config.is_contextual_bandit:
                trial_df = pd.DataFrame({key: [value] for key, value in context.items()})
                trial_df["note"] = BanditClassificationUtils.parse_features_and_action_to_note(action=action_taken, row=context, feature_columns=self.feature_columns, decimal_places=self.config.decimal_places)
            else:
                trial_df = pd.DataFrame({"note": [BanditClassificationUtils.parse_features_and_action_to_note(action_taken, decimal_places=self.config.decimal_places)]})
        trial_df["action"] = action_taken
        trial_df["label"] = reward
        trial_df["regret"] = regret
        trial_df["trial"] = trial
        trial_df["optimal_action"] = self.bandit.optimal_action()
        if not is_random_action:
            # Per-action diagnostics computed in get_single_trial_action.
            for action in self.action_space:
                trial_df[f"Q_value_{action}"] = self.Q_values[action]
            for action in self.action_space:
                trial_df[f"U_value_{action}"] = self.U_values[action]
            for action in self.action_space:
                trial_df[f"UCB_value_{action}"] = self.UCB_values[action]
        if trial == 0:
            self.D_rows = trial_df
        else:
            self.D_rows = pd.concat([self.D_rows, trial_df], ignore_index=True)
        # Persist after every trial so partial runs remain recoverable.
        self.save_D_rows()

    def run_experiment(self):
        """Run all trials sequentially and report total LLM API usage."""
        for trial in range(self.config.num_trials):
            self.single_trial(trial)
        print(f"\nTotal API calls: {self.num_api_calls}")
def main():
    """Entry point: build the experiment config from parsed CLI args and run it."""
    experiment_config = BanditClassificationExperimentConfig(**vars(args))
    BanditClassificationExperiment(experiment_config).run_experiment()


if __name__ == "__main__":
    main()